Refactored concurrency model in ntfy plugin

This commit is contained in:
Fabio Manganiello 2022-08-14 00:45:29 +02:00
parent 9e2b4a0043
commit f4672ce5c3
Signed by: blacklight
GPG key ID: D90FBA7F76362774

View file

@ -7,6 +7,7 @@ from typing import Optional, Collection, Mapping
import requests import requests
import websockets import websockets
import websockets.exceptions
from platypush.context import get_bus from platypush.context import get_bus
from platypush.message.event.ntfy import NotificationEvent from platypush.message.event.ntfy import NotificationEvent
@ -48,23 +49,14 @@ class NtfyPlugin(RunnablePlugin):
] ]
) )
self._event_loop: Optional[asyncio.AbstractEventLoop] = None
self._subscriptions = subscriptions or [] self._subscriptions = subscriptions or []
self._ws_proc = None self._ws_proc = None
def _connect(self):
if self.should_stop() or (self._ws_proc and self._ws_proc.is_alive()):
self.logger.debug('Already connected')
return
self._ws_proc = multiprocessing.Process(target=self._ws_process)
self._ws_proc.start()
async def _get_ws_handler(self, url): async def _get_ws_handler(self, url):
reconnect_wait_secs = 1 reconnect_wait_secs = 1
reconnect_wait_secs_max = 60 reconnect_wait_secs_max = 60
while True: while not self.should_stop():
self.logger.debug(f'Connecting to {url}') self.logger.debug(f'Connecting to {url}')
try: try:
@ -104,30 +96,38 @@ class NtfyPlugin(RunnablePlugin):
reconnect_wait_secs * 2, reconnect_wait_secs_max reconnect_wait_secs * 2, reconnect_wait_secs_max
) )
async def _ws_processor(self, urls):
await asyncio.wait([self._get_ws_handler(url) for url in urls])
def _ws_process(self): def _ws_process(self):
self._event_loop = get_or_create_event_loop() loop = get_or_create_event_loop()
try: try:
self._event_loop.run_until_complete( loop.run_until_complete(
self._ws_processor( asyncio.wait(
{f'{self._ws_url}/{sub}/ws' for sub in self._subscriptions} {
self._get_ws_handler(f'{self._ws_url}/{sub}/ws')
for sub in self._subscriptions
}
) )
) )
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
def main(self): def main(self):
if self.should_stop() or (self._ws_proc and self._ws_proc.is_alive()):
self.logger.debug('Already connected')
return
if self._subscriptions: if self._subscriptions:
self._connect() self._ws_proc = multiprocessing.Process(target=self._ws_process)
self._ws_proc.start()
self.wait_stop() self.wait_stop()
def stop(self): def stop(self):
if self._ws_proc: if self._ws_proc and self._ws_proc.is_alive():
self._ws_proc.kill() self._ws_proc.terminate()
self._ws_proc.join() try:
self._ws_proc = None self._ws_proc.join(timeout=3)
except TimeoutError:
self._ws_proc.kill()
super().stop() super().stop()