[#340] Alarm integration improvements.

- Emit `EntityDeleteEvent` when an alarm is expired and removed so
  clients can properly synchronize their state.

- `croniter.get_next()` should be timezone-aware.
This commit is contained in:
Fabio Manganiello 2023-12-09 13:33:42 +01:00
parent fcb6b621ab
commit 9d5c755188
Signed by: blacklight
GPG key ID: D90FBA7F76362774
2 changed files with 25 additions and 19 deletions

View file

@ -3,11 +3,12 @@ import sys
from threading import RLock
from typing import Collection, Generator, Optional, Dict, Any, List, Union
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, make_transient
from platypush.context import get_plugin
from platypush.entities import EntityManager
from platypush.entities.alarm import Alarm as AlarmTable
from platypush.entities.alarm import Alarm as DbAlarm
from platypush.message.event.entities import EntityDeleteEvent
from platypush.plugins import RunnablePlugin, action
from platypush.plugins.db import DbPlugin
from platypush.plugins.media import MediaPlugin
@ -141,7 +142,7 @@ class AlarmPlugin(RunnablePlugin, EntityManager):
with self._db_lock, self._db.get_session() as session:
yield session
def _merge_alarms(self, alarms: Dict[str, AlarmTable], session: Session):
def _merge_alarms(self, alarms: Dict[str, DbAlarm], session: Session):
for name, alarm in alarms.items():
if name in self.alarms:
existing_alarm = self.alarms[name]
@ -154,7 +155,7 @@ class AlarmPlugin(RunnablePlugin, EntityManager):
# If the alarm record on the db is static, but the alarm is no
# longer present in the configuration, then we want to delete it
if alarm.static:
session.delete(alarm)
self._clear_alarm(alarm, session)
else:
self.alarms[name] = Alarm.from_db(
alarm,
@ -165,7 +166,7 @@ class AlarmPlugin(RunnablePlugin, EntityManager):
def _sync_alarms(self):
with self._get_session() as session:
db_alarms = {
str(alarm.name): alarm for alarm in session.query(AlarmTable).all()
str(alarm.name): alarm for alarm in session.query(DbAlarm).all()
}
self._merge_alarms(db_alarms, session)
@ -178,6 +179,12 @@ class AlarmPlugin(RunnablePlugin, EntityManager):
self.publish_entities(self.alarms.values())
self._synced = True
def _clear_alarm(self, alarm: DbAlarm, session: Session):
self.alarms.pop(str(alarm.name), None)
session.delete(alarm)
make_transient(alarm)
self._bus.post(EntityDeleteEvent(entity=alarm))
def _clear_expired_alarms(self, session: Session):
expired_alarms = [
alarm
@ -188,17 +195,12 @@ class AlarmPlugin(RunnablePlugin, EntityManager):
if not expired_alarms:
return
expired_alarm_records = session.query(AlarmTable).filter(
AlarmTable.name.in_([alarm.name for alarm in expired_alarms])
expired_alarm_records = session.query(DbAlarm).filter(
DbAlarm.name.in_([alarm.name for alarm in expired_alarms])
)
for alarm in expired_alarms:
self.alarms.pop(alarm.name, None)
if alarm.static:
continue
for alarm in expired_alarm_records:
session.delete(alarm)
self._clear_alarm(alarm, session)
def _get_alarms(self) -> List[Alarm]:
return sorted(
@ -407,7 +409,7 @@ class AlarmPlugin(RunnablePlugin, EntityManager):
self.publish_entities(self.alarms.values())
return ret
def transform_entities(self, entities: Collection[Alarm], **_) -> List[AlarmTable]:
def transform_entities(self, entities: Collection[Alarm], **_) -> List[DbAlarm]:
return [alarm.to_db() for alarm in entities]
def main(self):

View file

@ -6,9 +6,10 @@ import threading
from typing import Callable, Optional, Union
import croniter
from dateutil.tz import gettz
from platypush.context import get_bus, get_plugin
from platypush.entities.alarm import Alarm as AlarmTable
from platypush.entities.alarm import Alarm as AlarmDb
from platypush.message.request import Request
from platypush.message.event.alarm import (
AlarmStartedEvent,
@ -111,7 +112,10 @@ class Alarm:
try:
# If when is a cron expression, get the next run time
t = croniter.croniter(self.when, now).get_next()
t = croniter.croniter(
self.when,
datetime.datetime.fromtimestamp(now).replace(tzinfo=gettz()),
).get_next()
except (AttributeError, croniter.CroniterBadCronError):
try:
# If when is an ISO-8601 timestamp, parse it
@ -268,7 +272,7 @@ class Alarm:
}
@classmethod
def from_db(cls, alarm: AlarmTable, **kwargs) -> 'Alarm':
def from_db(cls, alarm: AlarmDb, **kwargs) -> 'Alarm':
return cls(
when=str(alarm.when),
name=str(alarm.name),
@ -282,8 +286,8 @@ class Alarm:
**kwargs,
)
def to_db(self) -> AlarmTable:
return AlarmTable(
def to_db(self) -> AlarmDb:
return AlarmDb(
id=self.name,
name=self.name,
when=self.when,