The matrix plugin joins the AsyncRunnablePlugin family too

This commit is contained in:
Fabio Manganiello 2022-08-15 02:10:26 +02:00
parent 2797ffbe53
commit c04bc8d2bc
Signed by: blacklight
GPG key ID: D90FBA7F76362774
2 changed files with 141 additions and 113 deletions

View file

@ -1,13 +1,13 @@
import asyncio
import datetime import datetime
import json import json
import logging import logging
import multiprocessing
import os import os
import pathlib import pathlib
import re import re
from aiohttp import ClientConnectionError, ServerDisconnectedError
from dataclasses import dataclass from dataclasses import dataclass
from typing import Coroutine
from async_lru import alru_cache from async_lru import alru_cache
from nio import ( from nio import (
@ -21,6 +21,11 @@ from nio import (
InviteNameEvent, InviteNameEvent,
JoinedRoomsError, JoinedRoomsError,
KeyVerificationStart, KeyVerificationStart,
KeyVerificationEvent,
KeyVerificationAccept,
KeyVerificationMac,
KeyVerificationKey,
KeyVerificationCancel,
LoginResponse, LoginResponse,
MatrixRoom, MatrixRoom,
MegolmEvent, MegolmEvent,
@ -29,12 +34,14 @@ from nio import (
RoomGetEventError, RoomGetEventError,
RoomGetStateError, RoomGetStateError,
RoomGetStateResponse, RoomGetStateResponse,
RoomKeyRequest,
RoomMemberEvent, RoomMemberEvent,
RoomMessageText, RoomMessageText,
RoomMessageMedia, RoomMessageMedia,
RoomTopicEvent, RoomTopicEvent,
RoomUpgradeEvent, RoomUpgradeEvent,
StickerEvent, StickerEvent,
ToDeviceEvent,
UnknownEncryptedEvent, UnknownEncryptedEvent,
UnknownEvent, UnknownEvent,
) )
@ -42,7 +49,7 @@ from nio import (
from nio.client.async_client import client_session from nio.client.async_client import client_session
from platypush.config import Config from platypush.config import Config
from platypush.context import get_bus, get_or_create_event_loop from platypush.context import get_bus
from platypush.message.event.matrix import ( from platypush.message.event.matrix import (
MatrixCallAnswerEvent, MatrixCallAnswerEvent,
MatrixCallHangupEvent, MatrixCallHangupEvent,
@ -59,15 +66,14 @@ from platypush.message.event.matrix import (
MatrixStickerEvent, MatrixStickerEvent,
) )
from platypush.plugins import RunnablePlugin, action from platypush.plugins import AsyncRunnablePlugin, action
from platypush.schemas.matrix import ( from platypush.schemas.matrix import (
MatrixDeviceSchema, MatrixDeviceSchema,
MatrixEventIdSchema,
MatrixProfileSchema, MatrixProfileSchema,
MatrixRoomSchema, MatrixRoomSchema,
) )
from platypush.utils import set_thread_name
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -98,6 +104,7 @@ class MatrixClient(AsyncClient):
**kwargs, **kwargs,
): ):
credentials_file = os.path.abspath(os.path.expanduser(credentials_file)) credentials_file = os.path.abspath(os.path.expanduser(credentials_file))
if not store_path: if not store_path:
store_path = os.path.join(Config.get('workdir'), 'matrix', 'store') # type: ignore store_path = os.path.join(Config.get('workdir'), 'matrix', 'store') # type: ignore
if store_path: if store_path:
@ -115,6 +122,7 @@ class MatrixClient(AsyncClient):
self.logger = logging.getLogger(self.__class__.__name__) self.logger = logging.getLogger(self.__class__.__name__)
self._credentials_file = credentials_file self._credentials_file = credentials_file
self._autojoin_on_invite = autojoin_on_invite self._autojoin_on_invite = autojoin_on_invite
self._first_sync_performed = asyncio.Event()
async def _autojoin_room_callback(self, room: MatrixRoom, *_): async def _autojoin_room_callback(self, room: MatrixRoom, *_):
await self.join(room.room_id) # type: ignore await self.join(room.room_id) # type: ignore
@ -197,13 +205,15 @@ class MatrixClient(AsyncClient):
os.chmod(self._credentials_file, 0o600) os.chmod(self._credentials_file, 0o600)
self.logger.info('Synchronizing rooms') self.logger.info('Synchronizing rooms')
self._first_sync_performed.clear()
sync_token = self.loaded_sync_token sync_token = self.loaded_sync_token
self.loaded_sync_token = '' self.loaded_sync_token = ''
self._add_callbacks()
await self.sync(sync_filter={'room': {'timeline': {'limit': 1}}}) await self.sync(sync_filter={'room': {'timeline': {'limit': 1}}})
self.loaded_sync_token = sync_token self.loaded_sync_token = sync_token
self._first_sync_performed.set()
self.logger.info('Rooms synchronized') self.logger.info('Rooms synchronized')
self._add_callbacks()
return login_res return login_res
def _add_callbacks(self): def _add_callbacks(self):
@ -221,6 +231,14 @@ class MatrixClient(AsyncClient):
self.add_event_callback(self._on_unknown_encrypted_event, UnknownEncryptedEvent) # type: ignore self.add_event_callback(self._on_unknown_encrypted_event, UnknownEncryptedEvent) # type: ignore
self.add_event_callback(self._on_unknown_encrypted_event, MegolmEvent) # type: ignore self.add_event_callback(self._on_unknown_encrypted_event, MegolmEvent) # type: ignore
self.add_to_device_callback(self._on_key_verification_start, KeyVerificationStart) # type: ignore self.add_to_device_callback(self._on_key_verification_start, KeyVerificationStart) # type: ignore
self.add_to_device_callback(self._on_to_device_event, RoomKeyRequest) # type: ignore
self.add_to_device_callback(self._on_to_device_event, ToDeviceEvent)
self.add_to_device_callback(self._on_to_device_event, KeyVerificationStart)
self.add_to_device_callback(self._on_to_device_event, KeyVerificationKey)
self.add_to_device_callback(self._on_to_device_event, KeyVerificationMac)
self.add_to_device_callback(self._on_to_device_event, KeyVerificationAccept)
self.add_to_device_callback(self._on_to_device_event, KeyVerificationCancel)
self.add_to_device_callback(self._on_to_device_event, KeyVerificationEvent)
if self._autojoin_on_invite: if self._autojoin_on_invite:
self.add_event_callback(self._autojoin_room_callback, InviteNameEvent) # type: ignore self.add_event_callback(self._autojoin_room_callback, InviteNameEvent) # type: ignore
@ -283,6 +301,7 @@ class MatrixClient(AsyncClient):
) )
async def _on_room_message(self, room: MatrixRoom, event: RoomMessageText): async def _on_room_message(self, room: MatrixRoom, event: RoomMessageText):
if self._first_sync_performed.is_set():
get_bus().post( get_bus().post(
MatrixMessageEvent( MatrixMessageEvent(
**(await self._event_base_args(room, event)), **(await self._event_base_args(room, event)),
@ -297,7 +316,7 @@ class MatrixClient(AsyncClient):
elif event.membership == 'leave': elif event.membership == 'leave':
evt_type = MatrixRoomLeaveEvent evt_type = MatrixRoomLeaveEvent
if evt_type: if evt_type and self._first_sync_performed.is_set():
get_bus().post( get_bus().post(
evt_type( evt_type(
**(await self._event_base_args(room, event)), **(await self._event_base_args(room, event)),
@ -305,6 +324,7 @@ class MatrixClient(AsyncClient):
) )
async def _on_room_topic_changed(self, room: MatrixRoom, event: RoomTopicEvent): async def _on_room_topic_changed(self, room: MatrixRoom, event: RoomTopicEvent):
if self._first_sync_performed.is_set():
get_bus().post( get_bus().post(
MatrixRoomTopicChangedEvent( MatrixRoomTopicChangedEvent(
**(await self._event_base_args(room, event)), **(await self._event_base_args(room, event)),
@ -313,6 +333,7 @@ class MatrixClient(AsyncClient):
) )
async def _on_call_invite(self, room: MatrixRoom, event: CallInviteEvent): async def _on_call_invite(self, room: MatrixRoom, event: CallInviteEvent):
if self._first_sync_performed.is_set():
get_bus().post( get_bus().post(
MatrixCallInviteEvent( MatrixCallInviteEvent(
call_id=event.call_id, call_id=event.call_id,
@ -324,6 +345,7 @@ class MatrixClient(AsyncClient):
) )
async def _on_call_answer(self, room: MatrixRoom, event: CallAnswerEvent): async def _on_call_answer(self, room: MatrixRoom, event: CallAnswerEvent):
if self._first_sync_performed.is_set():
get_bus().post( get_bus().post(
MatrixCallAnswerEvent( MatrixCallAnswerEvent(
call_id=event.call_id, call_id=event.call_id,
@ -334,6 +356,7 @@ class MatrixClient(AsyncClient):
) )
async def _on_call_hangup(self, room: MatrixRoom, event: CallHangupEvent): async def _on_call_hangup(self, room: MatrixRoom, event: CallHangupEvent):
if self._first_sync_performed.is_set():
get_bus().post( get_bus().post(
MatrixCallHangupEvent( MatrixCallHangupEvent(
call_id=event.call_id, call_id=event.call_id,
@ -350,6 +373,7 @@ class MatrixClient(AsyncClient):
) )
async def _on_media_message(self, room: MatrixRoom, event: RoomMessageMedia): async def _on_media_message(self, room: MatrixRoom, event: RoomMessageMedia):
if self._first_sync_performed.is_set():
get_bus().post( get_bus().post(
MatrixMediaMessageEvent( MatrixMediaMessageEvent(
url=event.url, url=event.url,
@ -358,6 +382,7 @@ class MatrixClient(AsyncClient):
) )
async def _on_sticker_message(self, room: MatrixRoom, event: StickerEvent): async def _on_sticker_message(self, room: MatrixRoom, event: StickerEvent):
if self._first_sync_performed.is_set():
get_bus().post( get_bus().post(
MatrixStickerEvent( MatrixStickerEvent(
url=event.url, url=event.url,
@ -367,13 +392,14 @@ class MatrixClient(AsyncClient):
def _on_key_verification_start(self, event: KeyVerificationStart): def _on_key_verification_start(self, event: KeyVerificationStart):
assert self.olm, 'OLM state machine not initialized' assert self.olm, 'OLM state machine not initialized'
print('************ HERE')
print(event)
self.olm.handle_key_verification(event) self.olm.handle_key_verification(event)
def _on_to_device_event(self, event: ToDeviceEvent):
pass # TODO
async def _on_room_upgrade(self, room: MatrixRoom, event: RoomUpgradeEvent): async def _on_room_upgrade(self, room: MatrixRoom, event: RoomUpgradeEvent):
self.logger.info( self.logger.info(
'The room %s has been upgraded to %s', room.room_id, event.replacement_room 'The room %s has been moved to %s', room.room_id, event.replacement_room
) )
await self.room_leave(room.room_id) await self.room_leave(room.room_id)
@ -393,7 +419,7 @@ class MatrixClient(AsyncClient):
async def _on_unknown_event(self, room: MatrixRoom, event: UnknownEvent): async def _on_unknown_event(self, room: MatrixRoom, event: UnknownEvent):
evt = None evt = None
if event.type == 'm.reaction': if event.type == 'm.reaction' and self._first_sync_performed.is_set():
# Get the ID of the event this was a reaction to # Get the ID of the event this was a reaction to
relation_dict = event.source.get('content', {}).get('m.relates_to', {}) relation_dict = event.source.get('content', {}).get('m.relates_to', {})
reacted_to = relation_dict.get('event_id') reacted_to = relation_dict.get('event_id')
@ -418,7 +444,7 @@ class MatrixClient(AsyncClient):
) )
class MatrixPlugin(RunnablePlugin): class MatrixPlugin(AsyncRunnablePlugin):
""" """
Matrix chat integration. Matrix chat integration.
@ -493,16 +519,16 @@ class MatrixPlugin(RunnablePlugin):
if user_id and not re.match(user_id, '^@[a-zA-Z0-9.-_]+:.+'): if user_id and not re.match(user_id, '^@[a-zA-Z0-9.-_]+:.+'):
user_id = f'@{user_id}:{server_name}' user_id = f'@{user_id}:{server_name}'
self._matrix_proc: multiprocessing.Process | None = None # self._matrix_proc: multiprocessing.Process | None = None
self._user_id = user_id self._user_id = user_id
self._password = password self._password = password
self._access_token = access_token self._access_token = access_token
self._device_name = device_name self._device_name = device_name
self._device_id = device_id self._device_id = device_id
self._autojoin_on_invite = autojoin_on_invite self._autojoin_on_invite = autojoin_on_invite
self._event_loop = get_or_create_event_loop()
self._workdir = os.path.join(Config.get('workdir'), 'matrix') # type: ignore self._workdir = os.path.join(Config.get('workdir'), 'matrix') # type: ignore
self._credentials_file = os.path.join(self._workdir, 'credentials.json') self._credentials_file = os.path.join(self._workdir, 'credentials.json')
self._processed_responses = {}
self._client = self._get_client() self._client = self._get_client()
pathlib.Path(self._workdir).mkdir(parents=True, exist_ok=True) pathlib.Path(self._workdir).mkdir(parents=True, exist_ok=True)
@ -515,50 +541,42 @@ class MatrixPlugin(RunnablePlugin):
device_id=self._device_id, device_id=self._device_id,
) )
def _login(self) -> AsyncClient: async def _login(self) -> AsyncClient:
if not self._client: if not self._client:
self._client = self._get_client() self._client = self._get_client()
self._event_loop.run_until_complete( await self._client.login(
self._client.login(
password=self._password, password=self._password,
device_name=self._device_name, device_name=self._device_name,
token=self._access_token, token=self._access_token,
) )
)
return self._client return self._client
def _connect(self): async def listen(self):
if self.should_stop() or (self._matrix_proc and self._matrix_proc.is_alive()): while not self.should_stop():
self.logger.debug('Already connected') await self._login()
return assert self._client
self._login()
self._matrix_proc = multiprocessing.Process(target=self._run_client)
self._matrix_proc.start()
async def _run_async_client(self):
await self._client.sync_forever(timeout=0, full_state=True)
def _run_client(self):
set_thread_name('matrix-client')
while True:
try: try:
self._event_loop.run_until_complete(self._run_async_client()) await self._client.sync_forever(timeout=30000, full_state=True)
except (ClientConnectionError, ServerDisconnectedError):
self.logger.warning(
'Cannot connect to the Matrix server. Retrying in 15s'
)
self._should_stop.wait(15)
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
finally: finally:
if self._client: try:
self._event_loop.run_until_complete(self._client.close()) await self._client.close()
self._matrix_proc = None finally:
self._connect() self._client = None
def _loop_execute(self, coro: Coroutine):
assert self._loop, 'The loop is not running'
ret = asyncio.run_coroutine_threadsafe(coro, self._loop).result()
if hasattr(ret, 'transport_response'):
response = ret.transport_response
assert response.ok, f'{coro} failed with status {response.status}'
return ret
@action @action
def send_message( def send_message(
@ -579,11 +597,14 @@ class MatrixPlugin(RunnablePlugin):
:param tx_id: Unique transaction ID to associate to this message. :param tx_id: Unique transaction ID to associate to this message.
:param ignore_unverified_devices: If true, unverified devices will be :param ignore_unverified_devices: If true, unverified devices will be
ignored (default: False). ignored (default: False).
:return: .. schema:: matrix.MatrixEventIdSchema
""" """
message_type = 'm.' + message_type assert self._client, 'Client not connected'
return self._event_loop.run_until_complete( assert self._loop, 'The loop is not running'
ret = self._loop_execute(
self._client.room_send( self._client.room_send(
message_type=message_type, message_type='m.' + message_type,
room_id=room_id, room_id=room_id,
tx_id=tx_id, tx_id=tx_id,
ignore_unverified_devices=ignore_unverified_devices, ignore_unverified_devices=ignore_unverified_devices,
@ -593,6 +614,12 @@ class MatrixPlugin(RunnablePlugin):
) )
) )
ret = asyncio.run_coroutine_threadsafe(
ret.transport_response.json(), self._loop
).result()
return MatrixEventIdSchema().dump(ret)
@action @action
def get_profile(self, user_id: str): def get_profile(self, user_id: str):
""" """
@ -601,8 +628,9 @@ class MatrixPlugin(RunnablePlugin):
:param user_id: User ID. :param user_id: User ID.
:return: .. schema:: matrix.MatrixProfileSchema :return: .. schema:: matrix.MatrixProfileSchema
""" """
profile = self._event_loop.run_until_complete(self._client.get_profile(user_id)) # type: ignore assert self._client, 'Client not connected'
profile.user_id = user_id # type: ignore profile = self._loop_execute(self._client.get_profile(user_id))
profile.user_id = user_id
return MatrixProfileSchema().dump(profile) return MatrixProfileSchema().dump(profile)
@action @action
@ -613,10 +641,10 @@ class MatrixPlugin(RunnablePlugin):
:param room_id: room ID. :param room_id: room ID.
:return: .. schema:: matrix.MatrixRoomSchema :return: .. schema:: matrix.MatrixRoomSchema
""" """
response = self._event_loop.run_until_complete( assert self._client, 'Client not connected'
self._client.room_get_state(room_id) response = self._loop_execute(self._client.room_get_state(room_id))
)
assert not isinstance(response, RoomGetStateError), response.message assert not isinstance(response, RoomGetStateError), response.message
room_args = {'room_id': room_id, 'own_user_id': None, 'encrypted': False} room_args = {'room_id': room_id, 'own_user_id': None, 'encrypted': False}
room_params = {} room_params = {}
@ -642,7 +670,8 @@ class MatrixPlugin(RunnablePlugin):
:return: .. schema:: matrix.MatrixDeviceSchema(many=True) :return: .. schema:: matrix.MatrixDeviceSchema(many=True)
""" """
response = self._event_loop.run_until_complete(self._client.devices()) assert self._client, 'Client not connected'
response = self._loop_execute(self._client.devices())
assert not isinstance(response, DevicesError), response.message assert not isinstance(response, DevicesError), response.message
return MatrixDeviceSchema().dump(response.devices, many=True) return MatrixDeviceSchema().dump(response.devices, many=True)
@ -651,30 +680,19 @@ class MatrixPlugin(RunnablePlugin):
""" """
Retrieve the rooms that the user has joined. Retrieve the rooms that the user has joined.
""" """
response = self._event_loop.run_until_complete(self._client.joined_rooms()) assert self._client, 'Client not connected'
response = self._loop_execute(self._client.joined_rooms())
assert not isinstance(response, JoinedRoomsError), response.message assert not isinstance(response, JoinedRoomsError), response.message
return [self.get_room(room_id).output for room_id in response.rooms] return [self.get_room(room_id).output for room_id in response.rooms] # type: ignore
@action @action
def upload_keys(self): def upload_keys(self):
""" """
Synchronize the E2EE keys with the homeserver. Synchronize the E2EE keys with the homeserver.
""" """
self._event_loop.run_until_complete(self._client.keys_upload()) assert self._client, 'Client not connected'
self._loop_execute(self._client.keys_upload())
def main(self):
self._connect()
self.wait_stop()
def stop(self):
if self._matrix_proc:
self._matrix_proc.terminate()
self._matrix_proc.join(timeout=10)
self._matrix_proc.kill()
self._matrix_proc = None
super().stop()
# vim:sw=4:ts=4:et: # vim:sw=4:ts=4:et:

View file

@ -4,6 +4,16 @@ from marshmallow.schema import Schema
from platypush.schemas import DateTime from platypush.schemas import DateTime
class MatrixEventIdSchema(Schema):
event_id = fields.String(
required=True,
metadata={
'description': 'Event ID',
'example': '$24KT_aQz6sSKaZH8oTCibRTl62qywDgQXMpz5epXsW5',
},
)
class MatrixProfileSchema(Schema): class MatrixProfileSchema(Schema):
user_id = fields.String( user_id = fields.String(
required=True, required=True,