From 4043878afd0803f1553c21238f18baf62e9623ce Mon Sep 17 00:00:00 2001
From: Fabio Manganiello <info@fabiomanganiello.com>
Date: Sun, 14 Aug 2022 00:45:29 +0200
Subject: [PATCH] Refactored concurrency model in ntfy plugin

---
 platypush/plugins/ntfy/__init__.py | 48 ++++++++++++++----------------
 1 file changed, 23 insertions(+), 25 deletions(-)

diff --git a/platypush/plugins/ntfy/__init__.py b/platypush/plugins/ntfy/__init__.py
index 06427572..8619df7a 100644
--- a/platypush/plugins/ntfy/__init__.py
+++ b/platypush/plugins/ntfy/__init__.py
@@ -7,6 +7,7 @@ from typing import Optional, Collection, Mapping
 
 import requests
 import websockets
+import websockets.exceptions
 
 from platypush.context import get_bus
 from platypush.message.event.ntfy import NotificationEvent
@@ -48,23 +49,14 @@ class NtfyPlugin(RunnablePlugin):
             ]
         )
 
-        self._event_loop: Optional[asyncio.AbstractEventLoop] = None
         self._subscriptions = subscriptions or []
         self._ws_proc = None
 
-    def _connect(self):
-        if self.should_stop() or (self._ws_proc and self._ws_proc.is_alive()):
-            self.logger.debug('Already connected')
-            return
-
-        self._ws_proc = multiprocessing.Process(target=self._ws_process)
-        self._ws_proc.start()
-
     async def _get_ws_handler(self, url):
         reconnect_wait_secs = 1
         reconnect_wait_secs_max = 60
 
-        while True:
+        while not self.should_stop():
             self.logger.debug(f'Connecting to {url}')
 
             try:
@@ -104,32 +96,38 @@ class NtfyPlugin(RunnablePlugin):
                     reconnect_wait_secs * 2, reconnect_wait_secs_max
                 )
 
-    async def _ws_processor(self, urls):
-        await asyncio.wait([self._get_ws_handler(url) for url in urls])
-
     def _ws_process(self):
-        self._event_loop = get_or_create_event_loop()
+        loop = get_or_create_event_loop()
         try:
-            self._event_loop.run_until_complete(
-                self._ws_processor(
-                    {f'{self._ws_url}/{sub}/ws' for sub in self._subscriptions}
+            loop.run_until_complete(
+                asyncio.wait(
+                    {
+                        self._get_ws_handler(f'{self._ws_url}/{sub}/ws')
+                        for sub in self._subscriptions
+                    }
                 )
             )
         except KeyboardInterrupt:
             pass
 
     def main(self):
-        if self._subscriptions:
-            self._connect()
+        if self.should_stop() or (self._ws_proc and self._ws_proc.is_alive()):
+            self.logger.debug('Already connected')
+            return
 
-        while not self._should_stop.is_set():
-            self._should_stop.wait(timeout=1)
+        if self._subscriptions:
+            self._ws_proc = multiprocessing.Process(target=self._ws_process)
+            self._ws_proc.start()
+
+        self.wait_stop()
 
     def stop(self):
-        if self._ws_proc:
-            self._ws_proc.kill()
-            self._ws_proc.join()
-            self._ws_proc = None
+        if self._ws_proc and self._ws_proc.is_alive():
+            self._ws_proc.terminate()
+            try:
+                self._ws_proc.join(timeout=3)
+            except TimeoutError:
+                self._ws_proc.kill()
 
         super().stop()