diff --git a/platypush/plugins/alarm/__init__.py b/platypush/plugins/alarm/__init__.py index b9202b4c..567dd622 100644 --- a/platypush/plugins/alarm/__init__.py +++ b/platypush/plugins/alarm/__init__.py @@ -3,11 +3,12 @@ import sys from threading import RLock from typing import Collection, Generator, Optional, Dict, Any, List, Union -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, make_transient from platypush.context import get_plugin from platypush.entities import EntityManager -from platypush.entities.alarm import Alarm as AlarmTable +from platypush.entities.alarm import Alarm as DbAlarm +from platypush.message.event.entities import EntityDeleteEvent from platypush.plugins import RunnablePlugin, action from platypush.plugins.db import DbPlugin from platypush.plugins.media import MediaPlugin @@ -141,7 +142,7 @@ class AlarmPlugin(RunnablePlugin, EntityManager): with self._db_lock, self._db.get_session() as session: yield session - def _merge_alarms(self, alarms: Dict[str, AlarmTable], session: Session): + def _merge_alarms(self, alarms: Dict[str, DbAlarm], session: Session): for name, alarm in alarms.items(): if name in self.alarms: existing_alarm = self.alarms[name] @@ -154,7 +155,7 @@ class AlarmPlugin(RunnablePlugin, EntityManager): # If the alarm record on the db is static, but the alarm is no # longer present in the configuration, then we want to delete it if alarm.static: - session.delete(alarm) + self._clear_alarm(alarm, session) else: self.alarms[name] = Alarm.from_db( alarm, @@ -165,7 +166,7 @@ class AlarmPlugin(RunnablePlugin, EntityManager): def _sync_alarms(self): with self._get_session() as session: db_alarms = { - str(alarm.name): alarm for alarm in session.query(AlarmTable).all() + str(alarm.name): alarm for alarm in session.query(DbAlarm).all() } self._merge_alarms(db_alarms, session) @@ -178,6 +179,12 @@ class AlarmPlugin(RunnablePlugin, EntityManager): self.publish_entities(self.alarms.values()) self._synced = True + def _clear_alarm(self, alarm: DbAlarm, session: Session): + self.alarms.pop(str(alarm.name), None) + session.delete(alarm) + make_transient(alarm) + self._bus.post(EntityDeleteEvent(entity=alarm)) + def _clear_expired_alarms(self, session: Session): expired_alarms = [ alarm @@ -188,17 +195,12 @@ class AlarmPlugin(RunnablePlugin, EntityManager): if not expired_alarms: return - expired_alarm_records = session.query(AlarmTable).filter( - AlarmTable.name.in_([alarm.name for alarm in expired_alarms]) + expired_alarm_records = session.query(DbAlarm).filter( + DbAlarm.name.in_([alarm.name for alarm in expired_alarms]) ) - for alarm in expired_alarms: - self.alarms.pop(alarm.name, None) - if alarm.static: - continue - for alarm in expired_alarm_records: - session.delete(alarm) + self._clear_alarm(alarm, session) def _get_alarms(self) -> List[Alarm]: return sorted( @@ -407,7 +409,7 @@ class AlarmPlugin(RunnablePlugin, EntityManager): self.publish_entities(self.alarms.values()) return ret - def transform_entities(self, entities: Collection[Alarm], **_) -> List[AlarmTable]: + def transform_entities(self, entities: Collection[Alarm], **_) -> List[DbAlarm]: return [alarm.to_db() for alarm in entities] def main(self): diff --git a/platypush/plugins/alarm/_model.py b/platypush/plugins/alarm/_model.py index 369e501a..63078d01 100644 --- a/platypush/plugins/alarm/_model.py +++ b/platypush/plugins/alarm/_model.py @@ -6,9 +6,10 @@ import threading from typing import Callable, Optional, Union import croniter +from dateutil.tz import gettz from platypush.context import get_bus, get_plugin -from platypush.entities.alarm import Alarm as AlarmTable +from platypush.entities.alarm import Alarm as AlarmDb from platypush.message.request import Request from platypush.message.event.alarm import ( AlarmStartedEvent, @@ -111,7 +112,10 @@ class Alarm: try: # If when is a cron expression, get the next run time - t = croniter.croniter(self.when, now).get_next() + t = croniter.croniter( + self.when, + datetime.datetime.fromtimestamp(now).replace(tzinfo=gettz()), + ).get_next() except (AttributeError, croniter.CroniterBadCronError): try: # If when is an ISO-8601 timestamp, parse it @@ -268,7 +272,7 @@ class Alarm: } @classmethod - def from_db(cls, alarm: AlarmTable, **kwargs) -> 'Alarm': + def from_db(cls, alarm: AlarmDb, **kwargs) -> 'Alarm': return cls( when=str(alarm.when), name=str(alarm.name), @@ -282,8 +286,8 @@ class Alarm: **kwargs, ) - def to_db(self) -> AlarmTable: - return AlarmTable( + def to_db(self) -> AlarmDb: + return AlarmDb( id=self.name, name=self.name, when=self.when,