[#340] Alarm integration improvements.

- Emit `EntityDeleteEvent` when an alarm is expired and removed so
  clients can properly synchronize their state.

- `croniter.get_next()` should be timezone-aware.
This commit is contained in:
Fabio Manganiello 2023-12-09 13:33:42 +01:00
parent fcb6b621ab
commit 9d5c755188
Signed by: blacklight
GPG key ID: D90FBA7F76362774
2 changed files with 25 additions and 19 deletions

View file

@ -3,11 +3,12 @@ import sys
from threading import RLock from threading import RLock
from typing import Collection, Generator, Optional, Dict, Any, List, Union from typing import Collection, Generator, Optional, Dict, Any, List, Union
from sqlalchemy.orm import Session from sqlalchemy.orm import Session, make_transient
from platypush.context import get_plugin from platypush.context import get_plugin
from platypush.entities import EntityManager from platypush.entities import EntityManager
from platypush.entities.alarm import Alarm as AlarmTable from platypush.entities.alarm import Alarm as DbAlarm
from platypush.message.event.entities import EntityDeleteEvent
from platypush.plugins import RunnablePlugin, action from platypush.plugins import RunnablePlugin, action
from platypush.plugins.db import DbPlugin from platypush.plugins.db import DbPlugin
from platypush.plugins.media import MediaPlugin from platypush.plugins.media import MediaPlugin
@ -141,7 +142,7 @@ class AlarmPlugin(RunnablePlugin, EntityManager):
with self._db_lock, self._db.get_session() as session: with self._db_lock, self._db.get_session() as session:
yield session yield session
def _merge_alarms(self, alarms: Dict[str, AlarmTable], session: Session): def _merge_alarms(self, alarms: Dict[str, DbAlarm], session: Session):
for name, alarm in alarms.items(): for name, alarm in alarms.items():
if name in self.alarms: if name in self.alarms:
existing_alarm = self.alarms[name] existing_alarm = self.alarms[name]
@ -154,7 +155,7 @@ class AlarmPlugin(RunnablePlugin, EntityManager):
# If the alarm record on the db is static, but the alarm is no # If the alarm record on the db is static, but the alarm is no
# longer present in the configuration, then we want to delete it # longer present in the configuration, then we want to delete it
if alarm.static: if alarm.static:
session.delete(alarm) self._clear_alarm(alarm, session)
else: else:
self.alarms[name] = Alarm.from_db( self.alarms[name] = Alarm.from_db(
alarm, alarm,
@ -165,7 +166,7 @@ class AlarmPlugin(RunnablePlugin, EntityManager):
def _sync_alarms(self): def _sync_alarms(self):
with self._get_session() as session: with self._get_session() as session:
db_alarms = { db_alarms = {
str(alarm.name): alarm for alarm in session.query(AlarmTable).all() str(alarm.name): alarm for alarm in session.query(DbAlarm).all()
} }
self._merge_alarms(db_alarms, session) self._merge_alarms(db_alarms, session)
@ -178,6 +179,12 @@ class AlarmPlugin(RunnablePlugin, EntityManager):
self.publish_entities(self.alarms.values()) self.publish_entities(self.alarms.values())
self._synced = True self._synced = True
def _clear_alarm(self, alarm: DbAlarm, session: Session):
self.alarms.pop(str(alarm.name), None)
session.delete(alarm)
make_transient(alarm)
self._bus.post(EntityDeleteEvent(entity=alarm))
def _clear_expired_alarms(self, session: Session): def _clear_expired_alarms(self, session: Session):
expired_alarms = [ expired_alarms = [
alarm alarm
@ -188,17 +195,12 @@ class AlarmPlugin(RunnablePlugin, EntityManager):
if not expired_alarms: if not expired_alarms:
return return
expired_alarm_records = session.query(AlarmTable).filter( expired_alarm_records = session.query(DbAlarm).filter(
AlarmTable.name.in_([alarm.name for alarm in expired_alarms]) DbAlarm.name.in_([alarm.name for alarm in expired_alarms])
) )
for alarm in expired_alarms:
self.alarms.pop(alarm.name, None)
if alarm.static:
continue
for alarm in expired_alarm_records: for alarm in expired_alarm_records:
session.delete(alarm) self._clear_alarm(alarm, session)
def _get_alarms(self) -> List[Alarm]: def _get_alarms(self) -> List[Alarm]:
return sorted( return sorted(
@ -407,7 +409,7 @@ class AlarmPlugin(RunnablePlugin, EntityManager):
self.publish_entities(self.alarms.values()) self.publish_entities(self.alarms.values())
return ret return ret
def transform_entities(self, entities: Collection[Alarm], **_) -> List[AlarmTable]: def transform_entities(self, entities: Collection[Alarm], **_) -> List[DbAlarm]:
return [alarm.to_db() for alarm in entities] return [alarm.to_db() for alarm in entities]
def main(self): def main(self):

View file

@ -6,9 +6,10 @@ import threading
from typing import Callable, Optional, Union from typing import Callable, Optional, Union
import croniter import croniter
from dateutil.tz import gettz
from platypush.context import get_bus, get_plugin from platypush.context import get_bus, get_plugin
from platypush.entities.alarm import Alarm as AlarmTable from platypush.entities.alarm import Alarm as AlarmDb
from platypush.message.request import Request from platypush.message.request import Request
from platypush.message.event.alarm import ( from platypush.message.event.alarm import (
AlarmStartedEvent, AlarmStartedEvent,
@ -111,7 +112,10 @@ class Alarm:
try: try:
# If when is a cron expression, get the next run time # If when is a cron expression, get the next run time
t = croniter.croniter(self.when, now).get_next() t = croniter.croniter(
self.when,
datetime.datetime.fromtimestamp(now).replace(tzinfo=gettz()),
).get_next()
except (AttributeError, croniter.CroniterBadCronError): except (AttributeError, croniter.CroniterBadCronError):
try: try:
# If when is an ISO-8601 timestamp, parse it # If when is an ISO-8601 timestamp, parse it
@ -268,7 +272,7 @@ class Alarm:
} }
@classmethod @classmethod
def from_db(cls, alarm: AlarmTable, **kwargs) -> 'Alarm': def from_db(cls, alarm: AlarmDb, **kwargs) -> 'Alarm':
return cls( return cls(
when=str(alarm.when), when=str(alarm.when),
name=str(alarm.name), name=str(alarm.name),
@ -282,8 +286,8 @@ class Alarm:
**kwargs, **kwargs,
) )
def to_db(self) -> AlarmTable: def to_db(self) -> AlarmDb:
return AlarmTable( return AlarmDb(
id=self.name, id=self.name,
name=self.name, name=self.name,
when=self.when, when=self.when,