forked from platypush/platypush
[#340] Alarm integration improvements.
- Emit `EntityDeleteEvent` when an alarm is expired and removed so clients can properly synchronize their state. - `croniter.get_next()` should be timezone-aware.
This commit is contained in:
parent
fcb6b621ab
commit
9d5c755188
2 changed files with 25 additions and 19 deletions
|
@ -3,11 +3,12 @@ import sys
|
|||
from threading import RLock
|
||||
from typing import Collection, Generator, Optional, Dict, Any, List, Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, make_transient
|
||||
|
||||
from platypush.context import get_plugin
|
||||
from platypush.entities import EntityManager
|
||||
from platypush.entities.alarm import Alarm as AlarmTable
|
||||
from platypush.entities.alarm import Alarm as DbAlarm
|
||||
from platypush.message.event.entities import EntityDeleteEvent
|
||||
from platypush.plugins import RunnablePlugin, action
|
||||
from platypush.plugins.db import DbPlugin
|
||||
from platypush.plugins.media import MediaPlugin
|
||||
|
@ -141,7 +142,7 @@ class AlarmPlugin(RunnablePlugin, EntityManager):
|
|||
with self._db_lock, self._db.get_session() as session:
|
||||
yield session
|
||||
|
||||
def _merge_alarms(self, alarms: Dict[str, AlarmTable], session: Session):
|
||||
def _merge_alarms(self, alarms: Dict[str, DbAlarm], session: Session):
|
||||
for name, alarm in alarms.items():
|
||||
if name in self.alarms:
|
||||
existing_alarm = self.alarms[name]
|
||||
|
@ -154,7 +155,7 @@ class AlarmPlugin(RunnablePlugin, EntityManager):
|
|||
# If the alarm record on the db is static, but the alarm is no
|
||||
# longer present in the configuration, then we want to delete it
|
||||
if alarm.static:
|
||||
session.delete(alarm)
|
||||
self._clear_alarm(alarm, session)
|
||||
else:
|
||||
self.alarms[name] = Alarm.from_db(
|
||||
alarm,
|
||||
|
@ -165,7 +166,7 @@ class AlarmPlugin(RunnablePlugin, EntityManager):
|
|||
def _sync_alarms(self):
|
||||
with self._get_session() as session:
|
||||
db_alarms = {
|
||||
str(alarm.name): alarm for alarm in session.query(AlarmTable).all()
|
||||
str(alarm.name): alarm for alarm in session.query(DbAlarm).all()
|
||||
}
|
||||
|
||||
self._merge_alarms(db_alarms, session)
|
||||
|
@ -178,6 +179,12 @@ class AlarmPlugin(RunnablePlugin, EntityManager):
|
|||
self.publish_entities(self.alarms.values())
|
||||
self._synced = True
|
||||
|
||||
def _clear_alarm(self, alarm: DbAlarm, session: Session):
|
||||
self.alarms.pop(str(alarm.name), None)
|
||||
session.delete(alarm)
|
||||
make_transient(alarm)
|
||||
self._bus.post(EntityDeleteEvent(entity=alarm))
|
||||
|
||||
def _clear_expired_alarms(self, session: Session):
|
||||
expired_alarms = [
|
||||
alarm
|
||||
|
@ -188,17 +195,12 @@ class AlarmPlugin(RunnablePlugin, EntityManager):
|
|||
if not expired_alarms:
|
||||
return
|
||||
|
||||
expired_alarm_records = session.query(AlarmTable).filter(
|
||||
AlarmTable.name.in_([alarm.name for alarm in expired_alarms])
|
||||
expired_alarm_records = session.query(DbAlarm).filter(
|
||||
DbAlarm.name.in_([alarm.name for alarm in expired_alarms])
|
||||
)
|
||||
|
||||
for alarm in expired_alarms:
|
||||
self.alarms.pop(alarm.name, None)
|
||||
if alarm.static:
|
||||
continue
|
||||
|
||||
for alarm in expired_alarm_records:
|
||||
session.delete(alarm)
|
||||
self._clear_alarm(alarm, session)
|
||||
|
||||
def _get_alarms(self) -> List[Alarm]:
|
||||
return sorted(
|
||||
|
@ -407,7 +409,7 @@ class AlarmPlugin(RunnablePlugin, EntityManager):
|
|||
self.publish_entities(self.alarms.values())
|
||||
return ret
|
||||
|
||||
def transform_entities(self, entities: Collection[Alarm], **_) -> List[AlarmTable]:
|
||||
def transform_entities(self, entities: Collection[Alarm], **_) -> List[DbAlarm]:
|
||||
return [alarm.to_db() for alarm in entities]
|
||||
|
||||
def main(self):
|
||||
|
|
|
@ -6,9 +6,10 @@ import threading
|
|||
from typing import Callable, Optional, Union
|
||||
|
||||
import croniter
|
||||
from dateutil.tz import gettz
|
||||
|
||||
from platypush.context import get_bus, get_plugin
|
||||
from platypush.entities.alarm import Alarm as AlarmTable
|
||||
from platypush.entities.alarm import Alarm as AlarmDb
|
||||
from platypush.message.request import Request
|
||||
from platypush.message.event.alarm import (
|
||||
AlarmStartedEvent,
|
||||
|
@ -111,7 +112,10 @@ class Alarm:
|
|||
|
||||
try:
|
||||
# If when is a cron expression, get the next run time
|
||||
t = croniter.croniter(self.when, now).get_next()
|
||||
t = croniter.croniter(
|
||||
self.when,
|
||||
datetime.datetime.fromtimestamp(now).replace(tzinfo=gettz()),
|
||||
).get_next()
|
||||
except (AttributeError, croniter.CroniterBadCronError):
|
||||
try:
|
||||
# If when is an ISO-8601 timestamp, parse it
|
||||
|
@ -268,7 +272,7 @@ class Alarm:
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, alarm: AlarmTable, **kwargs) -> 'Alarm':
|
||||
def from_db(cls, alarm: AlarmDb, **kwargs) -> 'Alarm':
|
||||
return cls(
|
||||
when=str(alarm.when),
|
||||
name=str(alarm.name),
|
||||
|
@ -282,8 +286,8 @@ class Alarm:
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
def to_db(self) -> AlarmTable:
|
||||
return AlarmTable(
|
||||
def to_db(self) -> AlarmDb:
|
||||
return AlarmDb(
|
||||
id=self.name,
|
||||
name=self.name,
|
||||
when=self.when,
|
||||
|
|
Loading…
Reference in a new issue