diff --git a/platypush/cron/scheduler.py b/platypush/cron/scheduler.py index 55b40a98d..b7d816fe0 100644 --- a/platypush/cron/scheduler.py +++ b/platypush/cron/scheduler.py @@ -2,13 +2,14 @@ import datetime import enum import logging import threading -from typing import Dict +import time +from typing import Dict, Optional import croniter from dateutil.tz import gettz from platypush.procedure import Procedure -from platypush.utils import is_functional_cron +from platypush.utils import get_remaining_timeout, is_functional_cron logger = logging.getLogger('platypush:cron') @@ -198,6 +199,20 @@ class CronScheduler(threading.Thread): def should_stop(self): return self._should_stop.is_set() + def wait_stop(self, timeout: Optional[float] = None): + start = time.time() + stopped = self._should_stop.wait( + timeout=get_remaining_timeout(timeout=timeout, start=start) + ) + + if not stopped: + raise TimeoutError( + f'Timeout waiting for {self.__class__.__name__} to stop.' + ) + + if threading.get_ident() != self.ident: + self.join(timeout=get_remaining_timeout(timeout=timeout, start=start)) + def run(self): logger.info('Running cron scheduler') diff --git a/platypush/entities/_engine/__init__.py b/platypush/entities/_engine/__init__.py index cdd908740..0f3a22f33 100644 --- a/platypush/entities/_engine/__init__.py +++ b/platypush/entities/_engine/__init__.py @@ -1,5 +1,6 @@ from logging import getLogger -from threading import Thread, Event +from threading import Thread, Event, get_ident +from time import time from typing import Dict, Optional from platypush.context import get_bus @@ -9,6 +10,7 @@ from platypush.message.event.entities import EntityUpdateEvent from platypush.entities._base import EntityKey, EntitySavedCallback from platypush.entities._engine.queue import EntitiesQueue from platypush.entities._engine.repo import EntitiesRepository +from platypush.utils import get_remaining_timeout class EntitiesEngine(Thread): @@ -69,6 +71,20 @@ class EntitiesEngine(Thread): def stop(self): self._should_stop.set() + def wait_stop(self, timeout: Optional[float] = None): + start = time() + stopped = self._should_stop.wait( + timeout=get_remaining_timeout(timeout=timeout, start=start) + ) + + if not stopped: + raise TimeoutError( + f'Timeout waiting for {self.__class__.__name__} to stop.' + ) + + if get_ident() != self.ident: + self.join(timeout=get_remaining_timeout(timeout=timeout, start=start)) + def notify(self, *entities: Entity): """ Trigger an EntityUpdateEvent if the entity has been persisted, or queue