The matrix plugin joins the AsyncRunnablePlugin family too

This commit is contained in:
Fabio Manganiello 2022-08-15 02:10:26 +02:00
parent 2797ffbe53
commit c04bc8d2bc
Signed by untrusted user: blacklight
GPG key ID: D90FBA7F76362774
2 changed files with 141 additions and 113 deletions

View file

@ -1,13 +1,13 @@
import asyncio
import datetime
import json
import logging
import multiprocessing
import os
import pathlib
import re
from aiohttp import ClientConnectionError, ServerDisconnectedError
from dataclasses import dataclass
from typing import Coroutine
from async_lru import alru_cache
from nio import (
@ -21,6 +21,11 @@ from nio import (
InviteNameEvent,
JoinedRoomsError,
KeyVerificationStart,
KeyVerificationEvent,
KeyVerificationAccept,
KeyVerificationMac,
KeyVerificationKey,
KeyVerificationCancel,
LoginResponse,
MatrixRoom,
MegolmEvent,
@ -29,12 +34,14 @@ from nio import (
RoomGetEventError,
RoomGetStateError,
RoomGetStateResponse,
RoomKeyRequest,
RoomMemberEvent,
RoomMessageText,
RoomMessageMedia,
RoomTopicEvent,
RoomUpgradeEvent,
StickerEvent,
ToDeviceEvent,
UnknownEncryptedEvent,
UnknownEvent,
)
@ -42,7 +49,7 @@ from nio import (
from nio.client.async_client import client_session
from platypush.config import Config
from platypush.context import get_bus, get_or_create_event_loop
from platypush.context import get_bus
from platypush.message.event.matrix import (
MatrixCallAnswerEvent,
MatrixCallHangupEvent,
@ -59,15 +66,14 @@ from platypush.message.event.matrix import (
MatrixStickerEvent,
)
from platypush.plugins import RunnablePlugin, action
from platypush.plugins import AsyncRunnablePlugin, action
from platypush.schemas.matrix import (
MatrixDeviceSchema,
MatrixEventIdSchema,
MatrixProfileSchema,
MatrixRoomSchema,
)
from platypush.utils import set_thread_name
logger = logging.getLogger(__name__)
@ -98,6 +104,7 @@ class MatrixClient(AsyncClient):
**kwargs,
):
credentials_file = os.path.abspath(os.path.expanduser(credentials_file))
if not store_path:
store_path = os.path.join(Config.get('workdir'), 'matrix', 'store') # type: ignore
if store_path:
@ -115,6 +122,7 @@ class MatrixClient(AsyncClient):
self.logger = logging.getLogger(self.__class__.__name__)
self._credentials_file = credentials_file
self._autojoin_on_invite = autojoin_on_invite
self._first_sync_performed = asyncio.Event()
async def _autojoin_room_callback(self, room: MatrixRoom, *_):
await self.join(room.room_id) # type: ignore
@ -197,13 +205,15 @@ class MatrixClient(AsyncClient):
os.chmod(self._credentials_file, 0o600)
self.logger.info('Synchronizing rooms')
self._first_sync_performed.clear()
sync_token = self.loaded_sync_token
self.loaded_sync_token = ''
self._add_callbacks()
await self.sync(sync_filter={'room': {'timeline': {'limit': 1}}})
self.loaded_sync_token = sync_token
self._first_sync_performed.set()
self.logger.info('Rooms synchronized')
self._add_callbacks()
return login_res
def _add_callbacks(self):
@ -221,6 +231,14 @@ class MatrixClient(AsyncClient):
self.add_event_callback(self._on_unknown_encrypted_event, UnknownEncryptedEvent) # type: ignore
self.add_event_callback(self._on_unknown_encrypted_event, MegolmEvent) # type: ignore
self.add_to_device_callback(self._on_key_verification_start, KeyVerificationStart) # type: ignore
self.add_to_device_callback(self._on_to_device_event, RoomKeyRequest) # type: ignore
self.add_to_device_callback(self._on_to_device_event, ToDeviceEvent)
self.add_to_device_callback(self._on_to_device_event, KeyVerificationStart)
self.add_to_device_callback(self._on_to_device_event, KeyVerificationKey)
self.add_to_device_callback(self._on_to_device_event, KeyVerificationMac)
self.add_to_device_callback(self._on_to_device_event, KeyVerificationAccept)
self.add_to_device_callback(self._on_to_device_event, KeyVerificationCancel)
self.add_to_device_callback(self._on_to_device_event, KeyVerificationEvent)
if self._autojoin_on_invite:
self.add_event_callback(self._autojoin_room_callback, InviteNameEvent) # type: ignore
@ -283,12 +301,13 @@ class MatrixClient(AsyncClient):
)
async def _on_room_message(self, room: MatrixRoom, event: RoomMessageText):
get_bus().post(
MatrixMessageEvent(
**(await self._event_base_args(room, event)),
body=event.body,
if self._first_sync_performed.is_set():
get_bus().post(
MatrixMessageEvent(
**(await self._event_base_args(room, event)),
body=event.body,
)
)
)
async def _on_room_member(self, room: MatrixRoom, event: RoomMemberEvent):
evt_type = None
@ -297,7 +316,7 @@ class MatrixClient(AsyncClient):
elif event.membership == 'leave':
evt_type = MatrixRoomLeaveEvent
if evt_type:
if evt_type and self._first_sync_performed.is_set():
get_bus().post(
evt_type(
**(await self._event_base_args(room, event)),
@ -305,42 +324,46 @@ class MatrixClient(AsyncClient):
)
async def _on_room_topic_changed(self, room: MatrixRoom, event: RoomTopicEvent):
get_bus().post(
MatrixRoomTopicChangedEvent(
**(await self._event_base_args(room, event)),
topic=event.topic,
if self._first_sync_performed.is_set():
get_bus().post(
MatrixRoomTopicChangedEvent(
**(await self._event_base_args(room, event)),
topic=event.topic,
)
)
)
async def _on_call_invite(self, room: MatrixRoom, event: CallInviteEvent):
get_bus().post(
MatrixCallInviteEvent(
call_id=event.call_id,
version=event.version,
invite_validity=event.lifetime / 1000.0,
sdp=event.offer.get('sdp'),
**(await self._event_base_args(room, event)),
if self._first_sync_performed.is_set():
get_bus().post(
MatrixCallInviteEvent(
call_id=event.call_id,
version=event.version,
invite_validity=event.lifetime / 1000.0,
sdp=event.offer.get('sdp'),
**(await self._event_base_args(room, event)),
)
)
)
async def _on_call_answer(self, room: MatrixRoom, event: CallAnswerEvent):
get_bus().post(
MatrixCallAnswerEvent(
call_id=event.call_id,
version=event.version,
sdp=event.answer.get('sdp'),
**(await self._event_base_args(room, event)),
if self._first_sync_performed.is_set():
get_bus().post(
MatrixCallAnswerEvent(
call_id=event.call_id,
version=event.version,
sdp=event.answer.get('sdp'),
**(await self._event_base_args(room, event)),
)
)
)
async def _on_call_hangup(self, room: MatrixRoom, event: CallHangupEvent):
get_bus().post(
MatrixCallHangupEvent(
call_id=event.call_id,
version=event.version,
**(await self._event_base_args(room, event)),
if self._first_sync_performed.is_set():
get_bus().post(
MatrixCallHangupEvent(
call_id=event.call_id,
version=event.version,
**(await self._event_base_args(room, event)),
)
)
)
async def _on_room_created(self, room: MatrixRoom, event: RoomCreateEvent):
get_bus().post(
@ -350,30 +373,33 @@ class MatrixClient(AsyncClient):
)
async def _on_media_message(self, room: MatrixRoom, event: RoomMessageMedia):
get_bus().post(
MatrixMediaMessageEvent(
url=event.url,
**(await self._event_base_args(room, event)),
if self._first_sync_performed.is_set():
get_bus().post(
MatrixMediaMessageEvent(
url=event.url,
**(await self._event_base_args(room, event)),
)
)
)
async def _on_sticker_message(self, room: MatrixRoom, event: StickerEvent):
get_bus().post(
MatrixStickerEvent(
url=event.url,
**(await self._event_base_args(room, event)),
if self._first_sync_performed.is_set():
get_bus().post(
MatrixStickerEvent(
url=event.url,
**(await self._event_base_args(room, event)),
)
)
)
def _on_key_verification_start(self, event: KeyVerificationStart):
assert self.olm, 'OLM state machine not initialized'
print('************ HERE')
print(event)
self.olm.handle_key_verification(event)
def _on_to_device_event(self, event: ToDeviceEvent):
pass # TODO
async def _on_room_upgrade(self, room: MatrixRoom, event: RoomUpgradeEvent):
self.logger.info(
'The room %s has been upgraded to %s', room.room_id, event.replacement_room
'The room %s has been moved to %s', room.room_id, event.replacement_room
)
await self.room_leave(room.room_id)
@ -393,7 +419,7 @@ class MatrixClient(AsyncClient):
async def _on_unknown_event(self, room: MatrixRoom, event: UnknownEvent):
evt = None
if event.type == 'm.reaction':
if event.type == 'm.reaction' and self._first_sync_performed.is_set():
# Get the ID of the event this was a reaction to
relation_dict = event.source.get('content', {}).get('m.relates_to', {})
reacted_to = relation_dict.get('event_id')
@ -418,7 +444,7 @@ class MatrixClient(AsyncClient):
)
class MatrixPlugin(RunnablePlugin):
class MatrixPlugin(AsyncRunnablePlugin):
"""
Matrix chat integration.
@ -493,16 +519,16 @@ class MatrixPlugin(RunnablePlugin):
if user_id and not re.match(user_id, '^@[a-zA-Z0-9.-_]+:.+'):
user_id = f'@{user_id}:{server_name}'
self._matrix_proc: multiprocessing.Process | None = None
# self._matrix_proc: multiprocessing.Process | None = None
self._user_id = user_id
self._password = password
self._access_token = access_token
self._device_name = device_name
self._device_id = device_id
self._autojoin_on_invite = autojoin_on_invite
self._event_loop = get_or_create_event_loop()
self._workdir = os.path.join(Config.get('workdir'), 'matrix') # type: ignore
self._credentials_file = os.path.join(self._workdir, 'credentials.json')
self._processed_responses = {}
self._client = self._get_client()
pathlib.Path(self._workdir).mkdir(parents=True, exist_ok=True)
@ -515,50 +541,42 @@ class MatrixPlugin(RunnablePlugin):
device_id=self._device_id,
)
def _login(self) -> AsyncClient:
async def _login(self) -> AsyncClient:
if not self._client:
self._client = self._get_client()
self._event_loop.run_until_complete(
self._client.login(
password=self._password,
device_name=self._device_name,
token=self._access_token,
)
await self._client.login(
password=self._password,
device_name=self._device_name,
token=self._access_token,
)
return self._client
def _connect(self):
if self.should_stop() or (self._matrix_proc and self._matrix_proc.is_alive()):
self.logger.debug('Already connected')
return
async def listen(self):
while not self.should_stop():
await self._login()
assert self._client
self._login()
self._matrix_proc = multiprocessing.Process(target=self._run_client)
self._matrix_proc.start()
async def _run_async_client(self):
await self._client.sync_forever(timeout=0, full_state=True)
def _run_client(self):
set_thread_name('matrix-client')
while True:
try:
self._event_loop.run_until_complete(self._run_async_client())
except (ClientConnectionError, ServerDisconnectedError):
self.logger.warning(
'Cannot connect to the Matrix server. Retrying in 15s'
)
self._should_stop.wait(15)
await self._client.sync_forever(timeout=30000, full_state=True)
except KeyboardInterrupt:
pass
finally:
if self._client:
self._event_loop.run_until_complete(self._client.close())
self._matrix_proc = None
self._connect()
try:
await self._client.close()
finally:
self._client = None
def _loop_execute(self, coro: Coroutine):
assert self._loop, 'The loop is not running'
ret = asyncio.run_coroutine_threadsafe(coro, self._loop).result()
if hasattr(ret, 'transport_response'):
response = ret.transport_response
assert response.ok, f'{coro} failed with status {response.status}'
return ret
@action
def send_message(
@ -579,11 +597,14 @@ class MatrixPlugin(RunnablePlugin):
:param tx_id: Unique transaction ID to associate to this message.
:param ignore_unverified_devices: If true, unverified devices will be
ignored (default: False).
:return: .. schema:: matrix.MatrixEventIdSchema
"""
message_type = 'm.' + message_type
return self._event_loop.run_until_complete(
assert self._client, 'Client not connected'
assert self._loop, 'The loop is not running'
ret = self._loop_execute(
self._client.room_send(
message_type=message_type,
message_type='m.' + message_type,
room_id=room_id,
tx_id=tx_id,
ignore_unverified_devices=ignore_unverified_devices,
@ -593,6 +614,12 @@ class MatrixPlugin(RunnablePlugin):
)
)
ret = asyncio.run_coroutine_threadsafe(
ret.transport_response.json(), self._loop
).result()
return MatrixEventIdSchema().dump(ret)
@action
def get_profile(self, user_id: str):
"""
@ -601,8 +628,9 @@ class MatrixPlugin(RunnablePlugin):
:param user_id: User ID.
:return: .. schema:: matrix.MatrixProfileSchema
"""
profile = self._event_loop.run_until_complete(self._client.get_profile(user_id)) # type: ignore
profile.user_id = user_id # type: ignore
assert self._client, 'Client not connected'
profile = self._loop_execute(self._client.get_profile(user_id))
profile.user_id = user_id
return MatrixProfileSchema().dump(profile)
@action
@ -613,10 +641,10 @@ class MatrixPlugin(RunnablePlugin):
:param room_id: room ID.
:return: .. schema:: matrix.MatrixRoomSchema
"""
response = self._event_loop.run_until_complete(
self._client.room_get_state(room_id)
)
assert self._client, 'Client not connected'
response = self._loop_execute(self._client.room_get_state(room_id))
assert not isinstance(response, RoomGetStateError), response.message
room_args = {'room_id': room_id, 'own_user_id': None, 'encrypted': False}
room_params = {}
@ -642,7 +670,8 @@ class MatrixPlugin(RunnablePlugin):
:return: .. schema:: matrix.MatrixDeviceSchema(many=True)
"""
response = self._event_loop.run_until_complete(self._client.devices())
assert self._client, 'Client not connected'
response = self._loop_execute(self._client.devices())
assert not isinstance(response, DevicesError), response.message
return MatrixDeviceSchema().dump(response.devices, many=True)
@ -651,30 +680,19 @@ class MatrixPlugin(RunnablePlugin):
"""
Retrieve the rooms that the user has joined.
"""
response = self._event_loop.run_until_complete(self._client.joined_rooms())
assert self._client, 'Client not connected'
response = self._loop_execute(self._client.joined_rooms())
assert not isinstance(response, JoinedRoomsError), response.message
return [self.get_room(room_id).output for room_id in response.rooms]
return [self.get_room(room_id).output for room_id in response.rooms] # type: ignore
@action
def upload_keys(self):
"""
Synchronize the E2EE keys with the homeserver.
"""
self._event_loop.run_until_complete(self._client.keys_upload())
def main(self):
self._connect()
self.wait_stop()
def stop(self):
if self._matrix_proc:
self._matrix_proc.terminate()
self._matrix_proc.join(timeout=10)
self._matrix_proc.kill()
self._matrix_proc = None
super().stop()
assert self._client, 'Client not connected'
self._loop_execute(self._client.keys_upload())
# vim:sw=4:ts=4:et:

View file

@ -4,6 +4,16 @@ from marshmallow.schema import Schema
from platypush.schemas import DateTime
class MatrixEventIdSchema(Schema):
event_id = fields.String(
required=True,
metadata={
'description': 'Event ID',
'example': '$24KT_aQz6sSKaZH8oTCibRTl62qywDgQXMpz5epXsW5',
},
)
class MatrixProfileSchema(Schema):
user_id = fields.String(
required=True,