forked from platypush/platypush
Replaced deprecated sqlalchemy.ext.declarative with sqlalchemy.orm
parent 4b7eeaa4ed
commit 8a70f1d38e
7 changed files with 540 additions and 252 deletions
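The change itself is mechanical: since SQLAlchemy 1.4, `declarative_base` is exported directly by `sqlalchemy.orm`, and the old `sqlalchemy.ext.declarative` import path emits a deprecation warning. A minimal sketch of the swap applied throughout the files below (illustrative snippet, not part of the diff; usage of the returned base class is unchanged):

# Before: deprecated since SQLAlchemy 1.4
# from sqlalchemy.ext.declarative import declarative_base

# After: the same factory, now exported by sqlalchemy.orm
from sqlalchemy.orm import declarative_base

Base = declarative_base()  # models keep subclassing Base exactly as before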
@@ -3,8 +3,7 @@ import os
 from typing import Optional, Union, List, Dict, Any
 
 from sqlalchemy import create_engine, Column, Integer, String, DateTime
-from sqlalchemy.orm import sessionmaker, scoped_session
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
 
 from platypush.backend import Backend
 from platypush.config import Config

@@ -20,7 +19,7 @@ class Covid19Update(Base):
     """Models the Covid19Data table"""
 
     __tablename__ = 'covid19data'
-    __table_args__ = ({'sqlite_autoincrement': True})
+    __table_args__ = {'sqlite_autoincrement': True}
 
     country = Column(String, primary_key=True)
     confirmed = Column(Integer, nullable=False, default=0)

@@ -40,7 +39,12 @@ class Covid19Backend(Backend):
     """
 
     # noinspection PyProtectedMember
-    def __init__(self, country: Optional[Union[str, List[str]]], poll_seconds: Optional[float] = 3600.0, **kwargs):
+    def __init__(
+        self,
+        country: Optional[Union[str, List[str]]],
+        poll_seconds: Optional[float] = 3600.0,
+        **kwargs
+    ):
         """
         :param country: Default country (or list of countries) to retrieve the stats for. It can either be the full
             country name or the country code. Special values:

@@ -56,7 +60,9 @@ class Covid19Backend(Backend):
         super().__init__(poll_seconds=poll_seconds, **kwargs)
         self._plugin: Covid19Plugin = get_plugin('covid19')
         self.country: List[str] = self._plugin._get_countries(country)
-        self.workdir = os.path.join(os.path.expanduser(Config.get('workdir')), 'covid19')
+        self.workdir = os.path.join(
+            os.path.expanduser(Config.get('workdir')), 'covid19'
+        )
         self.dbfile = os.path.join(self.workdir, 'data.db')
         os.makedirs(self.workdir, exist_ok=True)
 

@@ -67,22 +73,30 @@ class Covid19Backend(Backend):
         self.logger.info('Stopped Covid19 backend')
 
     def _process_update(self, summary: Dict[str, Any], session: Session):
-        update_time = datetime.datetime.fromisoformat(summary['Date'].replace('Z', '+00:00'))
+        update_time = datetime.datetime.fromisoformat(
+            summary['Date'].replace('Z', '+00:00')
+        )
 
-        self.bus.post(Covid19UpdateEvent(
+        self.bus.post(
+            Covid19UpdateEvent(
                 country=summary['Country'],
                 country_code=summary['CountryCode'],
                 confirmed=summary['TotalConfirmed'],
                 deaths=summary['TotalDeaths'],
                 recovered=summary['TotalRecovered'],
                 update_time=update_time,
-        ))
+            )
+        )
 
-        session.merge(Covid19Update(country=summary['CountryCode'],
+        session.merge(
+            Covid19Update(
+                country=summary['CountryCode'],
                 confirmed=summary['TotalConfirmed'],
                 deaths=summary['TotalDeaths'],
                 recovered=summary['TotalRecovered'],
-                last_updated_at=update_time))
+                last_updated_at=update_time,
+            )
+        )
 
     def loop(self):
         # noinspection PyUnresolvedReferences

@@ -90,23 +104,30 @@ class Covid19Backend(Backend):
         if not summaries:
             return
 
-        engine = create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False})
+        engine = create_engine(
+            'sqlite:///{}'.format(self.dbfile),
+            connect_args={'check_same_thread': False},
+        )
         Base.metadata.create_all(engine)
         Session.configure(bind=engine)
         session = Session()
 
         last_records = {
             record.country: record
-            for record in session.query(Covid19Update).filter(Covid19Update.country.in_(self.country)).all()
+            for record in session.query(Covid19Update)
+            .filter(Covid19Update.country.in_(self.country))
+            .all()
         }
 
         for summary in summaries:
             country = summary['CountryCode']
             last_record = last_records.get(country)
-            if not last_record or \
-                    summary['TotalConfirmed'] != last_record.confirmed or \
-                    summary['TotalDeaths'] != last_record.deaths or \
-                    summary['TotalRecovered'] != last_record.recovered:
+            if (
+                not last_record
+                or summary['TotalConfirmed'] != last_record.confirmed
+                or summary['TotalDeaths'] != last_record.deaths
+                or summary['TotalRecovered'] != last_record.recovered
+            ):
                 self._process_update(summary=summary, session=session)
 
         session.commit()
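Alongside the import swap, several models in this commit drop redundant parentheses around `__table_args__`. The old `({'sqlite_autoincrement': True})` was already just a parenthesized dict, not a one-element tuple, so the change is behavior-preserving: SQLAlchemy accepts either a plain options dict or a tuple of constraints whose last element is an options dict. A short illustrative sketch (the `Example` model is hypothetical, not from this diff):

from sqlalchemy import Column, Integer, UniqueConstraint
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Example(Base):  # hypothetical model, for illustration only
    __tablename__ = 'example'
    __table_args__ = {'sqlite_autoincrement': True}  # dict form: options only
    # tuple form would be: (UniqueConstraint('other'), {'sqlite_autoincrement': True})
    id = Column(Integer, primary_key=True)
    other = Column(Integer)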
@@ -6,15 +6,28 @@ from typing import Optional, List
 
 import requests
 from sqlalchemy import create_engine, Column, String, DateTime
-from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.orm import sessionmaker, scoped_session
+from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
 
 from platypush.backend import Backend
 from platypush.config import Config
-from platypush.message.event.github import GithubPushEvent, GithubCommitCommentEvent, GithubCreateEvent, \
-    GithubDeleteEvent, GithubEvent, GithubForkEvent, GithubWikiEvent, GithubIssueCommentEvent, GithubIssueEvent, \
-    GithubMemberEvent, GithubPublicEvent, GithubPullRequestEvent, GithubPullRequestReviewCommentEvent, \
-    GithubReleaseEvent, GithubSponsorshipEvent, GithubWatchEvent
+from platypush.message.event.github import (
+    GithubPushEvent,
+    GithubCommitCommentEvent,
+    GithubCreateEvent,
+    GithubDeleteEvent,
+    GithubEvent,
+    GithubForkEvent,
+    GithubWikiEvent,
+    GithubIssueCommentEvent,
+    GithubIssueEvent,
+    GithubMemberEvent,
+    GithubPublicEvent,
+    GithubPullRequestEvent,
+    GithubPullRequestReviewCommentEvent,
+    GithubReleaseEvent,
+    GithubSponsorshipEvent,
+    GithubWatchEvent,
+)
 
 Base = declarative_base()
 Session = scoped_session(sessionmaker())

@@ -71,8 +84,17 @@ class GithubBackend(Backend):
 
     _base_url = 'https://api.github.com'
 
-    def __init__(self, user: str, user_token: str, repos: Optional[List[str]] = None, org: Optional[str] = None,
-                 poll_seconds: int = 60, max_events_per_scan: Optional[int] = 10, *args, **kwargs):
+    def __init__(
+        self,
+        user: str,
+        user_token: str,
+        repos: Optional[List[str]] = None,
+        org: Optional[str] = None,
+        poll_seconds: int = 60,
+        max_events_per_scan: Optional[int] = 10,
+        *args,
+        **kwargs
+    ):
         """
         If neither ``repos`` nor ``org`` is specified then the backend will monitor all new events on user level.
 

@@ -102,11 +124,17 @@ class GithubBackend(Backend):
 
     def _request(self, uri: str, method: str = 'get') -> dict:
         method = getattr(requests, method.lower())
-        return method(self._base_url + uri, auth=(self.user, self.user_token),
-                      headers={'Accept': 'application/vnd.github.v3+json'}).json()
+        return method(
+            self._base_url + uri,
+            auth=(self.user, self.user_token),
+            headers={'Accept': 'application/vnd.github.v3+json'},
+        ).json()
 
     def _init_db(self):
-        engine = create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False})
+        engine = create_engine(
+            'sqlite:///{}'.format(self.dbfile),
+            connect_args={'check_same_thread': False},
+        )
         Base.metadata.create_all(engine)
         Session.configure(bind=engine)
 

@@ -128,7 +156,11 @@ class GithubBackend(Backend):
     def _get_last_event_time(self, uri: str):
         with self.db_lock:
             record = self._get_or_create_resource(uri=uri, session=Session())
-            return record.last_updated_at.replace(tzinfo=datetime.timezone.utc) if record.last_updated_at else None
+            return (
+                record.last_updated_at.replace(tzinfo=datetime.timezone.utc)
+                if record.last_updated_at
+                else None
+            )
 
     def _update_last_event_time(self, uri: str, last_updated_at: datetime.datetime):
         with self.db_lock:

@@ -158,9 +190,18 @@ class GithubBackend(Backend):
             'WatchEvent': GithubWatchEvent,
         }
 
-        event_type = event_mapping[event['type']] if event['type'] in event_mapping else GithubEvent
-        return event_type(event_type=event['type'], actor=event['actor'], repo=event.get('repo', {}),
-                          payload=event['payload'], created_at=cls._to_datetime(event['created_at']))
+        event_type = (
+            event_mapping[event['type']]
+            if event['type'] in event_mapping
+            else GithubEvent
+        )
+        return event_type(
+            event_type=event['type'],
+            actor=event['actor'],
+            repo=event.get('repo', {}),
+            payload=event['payload'],
+            created_at=cls._to_datetime(event['created_at']),
+        )
 
     def _events_monitor(self, uri: str, method: str = 'get'):
         def thread():

@@ -175,7 +216,10 @@ class GithubBackend(Backend):
                 fired_events = []
 
                 for event in events:
-                    if self.max_events_per_scan and len(fired_events) >= self.max_events_per_scan:
+                    if (
+                        self.max_events_per_scan
+                        and len(fired_events) >= self.max_events_per_scan
+                    ):
                         break
 
                     event_time = self._to_datetime(event['created_at'])

@@ -189,12 +233,17 @@ class GithubBackend(Backend):
                 for event in fired_events:
                     self.bus.post(event)
 
-                self._update_last_event_time(uri=uri, last_updated_at=new_last_event_time)
+                self._update_last_event_time(
+                    uri=uri, last_updated_at=new_last_event_time
+                )
             except Exception as e:
-                self.logger.warning('Encountered exception while fetching events from {}: {}'.format(
-                    uri, str(e)))
+                self.logger.warning(
+                    'Encountered exception while fetching events from {}: {}'.format(
+                        uri, str(e)
+                    )
+                )
                 self.logger.exception(e)
-            finally:
+
             if self.wait_stop(timeout=self.poll_seconds):
                 break
 

@@ -206,12 +255,30 @@ class GithubBackend(Backend):
 
         if self.repos:
             for repo in self.repos:
-                monitors.append(threading.Thread(target=self._events_monitor('/networks/{repo}/events'.format(repo=repo))))
+                monitors.append(
+                    threading.Thread(
+                        target=self._events_monitor(
+                            '/networks/{repo}/events'.format(repo=repo)
+                        )
+                    )
+                )
         if self.org:
-            monitors.append(threading.Thread(target=self._events_monitor('/orgs/{org}/events'.format(org=self.org))))
+            monitors.append(
+                threading.Thread(
+                    target=self._events_monitor(
+                        '/orgs/{org}/events'.format(org=self.org)
+                    )
+                )
+            )
 
         if not (self.repos or self.org):
-            monitors.append(threading.Thread(target=self._events_monitor('/users/{user}/events'.format(user=self.user))))
+            monitors.append(
+                threading.Thread(
+                    target=self._events_monitor(
+                        '/users/{user}/events'.format(user=self.user)
+                    )
+                )
+            )
 
         for monitor in monitors:
             monitor.start()

@@ -222,4 +289,5 @@ class GithubBackend(Backend):
 
         self.logger.info('Github backend terminated')
 
+
 # vim:sw=4:ts=4:et:
@@ -2,11 +2,17 @@ import datetime
 import enum
 import os
 
-from sqlalchemy import create_engine, Column, Integer, String, DateTime, \
-    Enum, ForeignKey
+from sqlalchemy import (
+    create_engine,
+    Column,
+    Integer,
+    String,
+    DateTime,
+    Enum,
+    ForeignKey,
+)
 
-from sqlalchemy.orm import sessionmaker, scoped_session
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
 from sqlalchemy.sql.expression import func
 
 from platypush.backend.http.request import HttpRequest

@@ -44,18 +50,31 @@ class RssUpdates(HttpRequest):
 
     """
 
-    user_agent = 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) ' + \
-        'Chrome/62.0.3202.94 Safari/537.36'
+    user_agent = (
+        'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) '
+        + 'Chrome/62.0.3202.94 Safari/537.36'
+    )
 
-    def __init__(self, url, title=None, headers=None, params=None, max_entries=None,
-                 extract_content=False, digest_format=None, user_agent: str = user_agent,
-                 body_style: str = 'font-size: 22px; ' +
-                 'font-family: "Merriweather", Georgia, "Times New Roman", Times, serif;',
+    def __init__(
+        self,
+        url,
+        title=None,
+        headers=None,
+        params=None,
+        max_entries=None,
+        extract_content=False,
+        digest_format=None,
+        user_agent: str = user_agent,
+        body_style: str = 'font-size: 22px; '
+        + 'font-family: "Merriweather", Georgia, "Times New Roman", Times, serif;',
         title_style: str = 'margin-top: 30px',
        subtitle_style: str = 'margin-top: 10px; page-break-after: always',
         article_title_style: str = 'page-break-before: always',
         article_link_style: str = 'color: #555; text-decoration: none; border-bottom: 1px dotted',
-        article_content_style: str = '', *argv, **kwargs):
+        article_content_style: str = '',
+        *argv,
+        **kwargs,
+    ):
         """
         :param url: URL to the RSS feed to be monitored.
         :param title: Optional title for the feed.

@@ -91,7 +110,9 @@ class RssUpdates(HttpRequest):
         # If true, then the http.webpage plugin will be used to parse the content
         self.extract_content = extract_content
 
-        self.digest_format = digest_format.lower() if digest_format else None  # Supported formats: html, pdf
+        self.digest_format = (
+            digest_format.lower() if digest_format else None
+        )  # Supported formats: html, pdf
 
         os.makedirs(os.path.expanduser(os.path.dirname(self.dbfile)), exist_ok=True)
 

@@ -119,7 +140,11 @@ class RssUpdates(HttpRequest):
 
     @staticmethod
     def _get_latest_update(session, source_id):
-        return session.query(func.max(FeedEntry.published)).filter_by(source_id=source_id).scalar()
+        return (
+            session.query(func.max(FeedEntry.published))
+            .filter_by(source_id=source_id)
+            .scalar()
+        )
 
     def _parse_entry_content(self, link):
         self.logger.info('Extracting content from {}'.format(link))

@@ -130,14 +155,20 @@ class RssUpdates(HttpRequest):
         errors = response.errors
 
         if not output:
-            self.logger.warning('Mercury parser error: {}'.format(errors or '[unknown error]'))
+            self.logger.warning(
+                'Mercury parser error: {}'.format(errors or '[unknown error]')
+            )
             return
 
         return output.get('content')
 
     def get_new_items(self, response):
         import feedparser
-        engine = create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False})
+
+        engine = create_engine(
+            'sqlite:///{}'.format(self.dbfile),
+            connect_args={'check_same_thread': False},
+        )
 
         Base.metadata.create_all(engine)
         Session.configure(bind=engine)

@@ -157,12 +188,16 @@ class RssUpdates(HttpRequest):
 
         content = u'''
             <h1 style="{title_style}">{title}</h1>
-            <h2 style="{subtitle_style}">Feeds digest generated on {creation_date}</h2>'''.\
-            format(title_style=self.title_style, title=self.title, subtitle_style=self.subtitle_style,
-                   creation_date=datetime.datetime.now().strftime('%d %B %Y, %H:%M'))
+            <h2 style="{subtitle_style}">Feeds digest generated on {creation_date}</h2>'''.format(
+            title_style=self.title_style,
+            title=self.title,
+            subtitle_style=self.subtitle_style,
+            creation_date=datetime.datetime.now().strftime('%d %B %Y, %H:%M'),
+        )
 
-        self.logger.info('Parsed {:d} items from RSS feed <{}>'
-                         .format(len(feed.entries), self.url))
+        self.logger.info(
+            'Parsed {:d} items from RSS feed <{}>'.format(len(feed.entries), self.url)
+        )
 
         for entry in feed.entries:
             if not entry.published_parsed:

@@ -171,9 +206,10 @@ class RssUpdates(HttpRequest):
             try:
                 entry_timestamp = datetime.datetime(*entry.published_parsed[:6])
 
-                if latest_update is None \
-                        or entry_timestamp > latest_update:
-                    self.logger.info('Processed new item from RSS feed <{}>'.format(self.url))
+                if latest_update is None or entry_timestamp > latest_update:
+                    self.logger.info(
+                        'Processed new item from RSS feed <{}>'.format(self.url)
+                    )
                     entry.summary = entry.summary if hasattr(entry, 'summary') else None
 
                     if self.extract_content:

@@ -188,9 +224,13 @@ class RssUpdates(HttpRequest):
                             <a href="{link}" target="_blank" style="{article_link_style}">{title}</a>
                         </h1>
                         <div class="_parsed-content" style="{article_content_style}">{content}</div>'''.format(
-                        article_title_style=self.article_title_style, article_link_style=self.article_link_style,
-                        article_content_style=self.article_content_style, link=entry.link, title=entry.title,
-                        content=entry.content)
+                        article_title_style=self.article_title_style,
+                        article_link_style=self.article_link_style,
+                        article_content_style=self.article_content_style,
+                        link=entry.link,
+                        title=entry.title,
+                        content=entry.content,
+                    )
 
                     e = {
                         'entry_id': entry.id,

@@ -207,21 +247,32 @@ class RssUpdates(HttpRequest):
                 if self.max_entries and len(entries) > self.max_entries:
                     break
             except Exception as e:
-                self.logger.warning('Exception encountered while parsing RSS ' +
-                                    'RSS feed {}: {}'.format(entry.link, str(e)))
+                self.logger.warning(
+                    'Exception encountered while parsing RSS '
+                    + f'RSS feed {entry.link}: {e}'
+                )
                 self.logger.exception(e)
 
         source_record.last_updated_at = parse_start_time
         digest_filename = None
 
         if entries:
-            self.logger.info('Parsed {} new entries from the RSS feed {}'.format(
-                len(entries), self.title))
+            self.logger.info(
+                'Parsed {} new entries from the RSS feed {}'.format(
+                    len(entries), self.title
+                )
+            )
 
             if self.digest_format:
-                digest_filename = os.path.join(self.workdir, 'cache', '{}_{}.{}'.format(
-                    datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'),
-                    self.title, self.digest_format))
+                digest_filename = os.path.join(
+                    self.workdir,
+                    'cache',
+                    '{}_{}.{}'.format(
+                        datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'),
+                        self.title,
+                        self.digest_format,
+                    ),
+                )
 
                 os.makedirs(os.path.dirname(digest_filename), exist_ok=True)
 

@@ -233,12 +284,15 @@ class RssUpdates(HttpRequest):
                     </head>
                     <body style="{body_style}">{content}</body>
                 </html>
-                '''.format(title=self.title, body_style=self.body_style, content=content)
+                '''.format(
+                    title=self.title, body_style=self.body_style, content=content
+                )
+
                 with open(digest_filename, 'w', encoding='utf-8') as f:
                     f.write(content)
             elif self.digest_format == 'pdf':
                 from weasyprint import HTML, CSS
 
                 try:
                     from weasyprint.fonts import FontConfiguration
                 except ImportError:

@@ -246,37 +300,47 @@ class RssUpdates(HttpRequest):
 
                     body_style = 'body { ' + self.body_style + ' }'
                     font_config = FontConfiguration()
-                    css = [CSS('https://fonts.googleapis.com/css?family=Merriweather'),
-                           CSS(string=body_style, font_config=font_config)]
+                    css = [
+                        CSS('https://fonts.googleapis.com/css?family=Merriweather'),
+                        CSS(string=body_style, font_config=font_config),
+                    ]
 
                     HTML(string=content).write_pdf(digest_filename, stylesheets=css)
                 else:
-                    raise RuntimeError('Unsupported format: {}. Supported formats: ' +
-                                       'html or pdf'.format(self.digest_format))
+                    raise RuntimeError(
+                        f'Unsupported format: {self.digest_format}. Supported formats: html, pdf'
+                    )
 
-                digest_entry = FeedDigest(source_id=source_record.id,
+                digest_entry = FeedDigest(
+                    source_id=source_record.id,
                     format=self.digest_format,
-                    filename=digest_filename)
+                    filename=digest_filename,
+                )
 
                 session.add(digest_entry)
-                self.logger.info('{} digest ready: {}'.format(self.digest_format, digest_filename))
+                self.logger.info(
+                    '{} digest ready: {}'.format(self.digest_format, digest_filename)
+                )
 
         session.commit()
         self.logger.info('Parsing RSS feed {}: completed'.format(self.title))
 
-        return NewFeedEvent(request=dict(self), response=entries,
+        return NewFeedEvent(
+            request=dict(self),
+            response=entries,
             source_id=source_record.id,
             source_title=source_record.title,
            title=self.title,
             digest_format=self.digest_format,
-            digest_filename=digest_filename)
+            digest_filename=digest_filename,
+        )
 
 
 class FeedSource(Base):
     """Models the FeedSource table, containing RSS sources to be parsed"""
 
     __tablename__ = 'FeedSource'
-    __table_args__ = ({'sqlite_autoincrement': True})
+    __table_args__ = {'sqlite_autoincrement': True}
 
     id = Column(Integer, primary_key=True)
     title = Column(String)

@@ -288,7 +352,7 @@ class FeedEntry(Base):
     """Models the FeedEntry table, which contains RSS entries"""
 
     __tablename__ = 'FeedEntry'
-    __table_args__ = ({'sqlite_autoincrement': True})
+    __table_args__ = {'sqlite_autoincrement': True}
 
     id = Column(Integer, primary_key=True)
     entry_id = Column(String)

@@ -309,7 +373,7 @@ class FeedDigest(Base):
         pdf = 2
 
     __tablename__ = 'FeedDigest'
-    __table_args__ = ({'sqlite_autoincrement': True})
+    __table_args__ = {'sqlite_autoincrement': True}
 
     id = Column(Integer, primary_key=True)
     source_id = Column(Integer, ForeignKey('FeedSource.id'), nullable=False)

@@ -317,4 +381,5 @@ class FeedDigest(Base):
     filename = Column(String, nullable=False)
     created_at = Column(DateTime, nullable=False, default=datetime.datetime.utcnow)
 
+
 # vim:sw=4:ts=4:et:
@@ -8,15 +8,18 @@ from queue import Queue, Empty
 from threading import Thread, RLock
 from typing import List, Dict, Any, Optional, Tuple
 
-from sqlalchemy import create_engine, Column, Integer, String, DateTime
-import sqlalchemy.engine as engine
-from sqlalchemy.orm import sessionmaker, scoped_session
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy import engine, create_engine, Column, Integer, String, DateTime
+from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
 
 from platypush.backend import Backend
 from platypush.config import Config
 from platypush.context import get_plugin
-from platypush.message.event.mail import MailReceivedEvent, MailSeenEvent, MailFlaggedEvent, MailUnflaggedEvent
+from platypush.message.event.mail import (
+    MailReceivedEvent,
+    MailSeenEvent,
+    MailFlaggedEvent,
+    MailUnflaggedEvent,
+)
 from platypush.plugins.mail import MailInPlugin, Mail
 
 # <editor-fold desc="Database tables">

@@ -26,6 +29,7 @@ Session = scoped_session(sessionmaker())
 
 class MailboxStatus(Base):
     """Models the MailboxStatus table, containing information about the state of a monitored mailbox."""
+
     __tablename__ = 'MailboxStatus'
 
     mailbox_id = Column(Integer, primary_key=True)

@@ -64,8 +68,13 @@ class MailBackend(Backend):
 
     """
 
-    def __init__(self, mailboxes: List[Dict[str, Any]], timeout: Optional[int] = 60, poll_seconds: Optional[int] = 60,
-                 **kwargs):
+    def __init__(
+        self,
+        mailboxes: List[Dict[str, Any]],
+        timeout: Optional[int] = 60,
+        poll_seconds: Optional[int] = 60,
+        **kwargs
+    ):
         """
         :param mailboxes: List of mailboxes to be monitored. Each mailbox entry contains a ``plugin`` attribute to
             identify the :class:`platypush.plugins.mail.MailInPlugin` plugin that will be used (e.g. ``mail.imap``)

@@ -128,9 +137,13 @@ class MailBackend(Backend):
 
         # Parse mailboxes
         for i, mbox in enumerate(mailboxes):
-            assert 'plugin' in mbox, 'No plugin attribute specified for mailbox n.{}'.format(i)
+            assert (
+                'plugin' in mbox
+            ), 'No plugin attribute specified for mailbox n.{}'.format(i)
             plugin = get_plugin(mbox.pop('plugin'))
-            assert isinstance(plugin, MailInPlugin), '{} is not a MailInPlugin'.format(plugin)
+            assert isinstance(plugin, MailInPlugin), '{} is not a MailInPlugin'.format(
+                plugin
+            )
             name = mbox.pop('name') if 'name' in mbox else 'Mailbox #{}'.format(i + 1)
             self.mailboxes.append(Mailbox(plugin=plugin, name=name, args=mbox))
 

@@ -144,7 +157,10 @@ class MailBackend(Backend):
 
     # <editor-fold desc="Database methods">
     def _db_get_engine(self) -> engine.Engine:
-        return create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False})
+        return create_engine(
+            'sqlite:///{}'.format(self.dbfile),
+            connect_args={'check_same_thread': False},
+        )
 
     def _db_load_mailboxes_status(self) -> None:
         mailbox_ids = list(range(len(self.mailboxes)))

@@ -153,12 +169,18 @@ class MailBackend(Backend):
             session = Session()
             records = {
                 record.mailbox_id: record
-                for record in session.query(MailboxStatus).filter(MailboxStatus.mailbox_id.in_(mailbox_ids)).all()
+                for record in session.query(MailboxStatus)
+                .filter(MailboxStatus.mailbox_id.in_(mailbox_ids))
+                .all()
             }
 
-            for mbox_id, mbox in enumerate(self.mailboxes):
+            for mbox_id, _ in enumerate(self.mailboxes):
                 if mbox_id not in records:
-                    record = MailboxStatus(mailbox_id=mbox_id, unseen_message_ids='[]', flagged_message_ids='[]')
+                    record = MailboxStatus(
+                        mailbox_id=mbox_id,
+                        unseen_message_ids='[]',
+                        flagged_message_ids='[]',
+                    )
                     session.add(record)
                 else:
                     record = records[mbox_id]

@@ -170,19 +192,25 @@ class MailBackend(Backend):
 
             session.commit()
 
-    def _db_get_mailbox_status(self, mailbox_ids: List[int]) -> Dict[int, MailboxStatus]:
+    def _db_get_mailbox_status(
+        self, mailbox_ids: List[int]
+    ) -> Dict[int, MailboxStatus]:
         with self._db_lock:
             session = Session()
             return {
                 record.mailbox_id: record
-                for record in session.query(MailboxStatus).filter(MailboxStatus.mailbox_id.in_(mailbox_ids)).all()
+                for record in session.query(MailboxStatus)
+                .filter(MailboxStatus.mailbox_id.in_(mailbox_ids))
+                .all()
             }
 
     # </editor-fold>
 
     # <editor-fold desc="Parse unread messages logic">
     @staticmethod
-    def _check_thread(unread_queue: Queue, flagged_queue: Queue, plugin: MailInPlugin, **args):
+    def _check_thread(
+        unread_queue: Queue, flagged_queue: Queue, plugin: MailInPlugin, **args
+    ):
         def thread():
             # noinspection PyUnresolvedReferences
             unread = plugin.search_unseen_messages(**args).output

@@ -194,8 +222,9 @@ class MailBackend(Backend):
 
         return thread
 
-    def _get_unread_seen_msgs(self, mailbox_idx: int, unread_msgs: Dict[int, Mail]) \
-            -> Tuple[Dict[int, Mail], Dict[int, Mail]]:
+    def _get_unread_seen_msgs(
+        self, mailbox_idx: int, unread_msgs: Dict[int, Mail]
+    ) -> Tuple[Dict[int, Mail], Dict[int, Mail]]:
         prev_unread_msgs = self._unread_msgs[mailbox_idx]
 
         return {

@@ -208,8 +237,9 @@ class MailBackend(Backend):
             if msg_id not in unread_msgs
         }
 
-    def _get_flagged_unflagged_msgs(self, mailbox_idx: int, flagged_msgs: Dict[int, Mail]) \
-            -> Tuple[Dict[int, Mail], Dict[int, Mail]]:
+    def _get_flagged_unflagged_msgs(
+        self, mailbox_idx: int, flagged_msgs: Dict[int, Mail]
+    ) -> Tuple[Dict[int, Mail], Dict[int, Mail]]:
         prev_flagged_msgs = self._flagged_msgs[mailbox_idx]
 
         return {

@@ -222,21 +252,36 @@ class MailBackend(Backend):
             if msg_id not in flagged_msgs
         }
 
-    def _process_msg_events(self, mailbox_id: int, unread: List[Mail], seen: List[Mail],
-                            flagged: List[Mail], unflagged: List[Mail], last_checked_date: Optional[datetime] = None):
+    def _process_msg_events(
+        self,
+        mailbox_id: int,
+        unread: List[Mail],
+        seen: List[Mail],
+        flagged: List[Mail],
+        unflagged: List[Mail],
+        last_checked_date: Optional[datetime] = None,
+    ):
         for msg in unread:
             if msg.date and last_checked_date and msg.date < last_checked_date:
                 continue
-            self.bus.post(MailReceivedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg))
+            self.bus.post(
+                MailReceivedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)
+            )
 
         for msg in seen:
-            self.bus.post(MailSeenEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg))
+            self.bus.post(
+                MailSeenEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)
+            )
 
         for msg in flagged:
-            self.bus.post(MailFlaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg))
+            self.bus.post(
+                MailFlaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)
+            )
 
         for msg in unflagged:
-            self.bus.post(MailUnflaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg))
+            self.bus.post(
+                MailUnflaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)
+            )
 
     def _check_mailboxes(self) -> List[Tuple[Dict[int, Mail], Dict[int, Mail]]]:
         workers = []

@@ -245,8 +290,14 @@ class MailBackend(Backend):
 
         for mbox in self.mailboxes:
             unread_queue, flagged_queue = [Queue()] * 2
-            worker = Thread(target=self._check_thread(unread_queue=unread_queue, flagged_queue=flagged_queue,
-                                                      plugin=mbox.plugin, **mbox.args))
+            worker = Thread(
+                target=self._check_thread(
+                    unread_queue=unread_queue,
+                    flagged_queue=flagged_queue,
+                    plugin=mbox.plugin,
+                    **mbox.args
+                )
+            )
             worker.start()
             workers.append(worker)
             queues.append((unread_queue, flagged_queue))

@@ -260,7 +311,11 @@ class MailBackend(Backend):
                 flagged = flagged_queue.get(timeout=self.timeout)
                 results.append((unread, flagged))
             except Empty:
-                self.logger.warning('Checks on mailbox #{} timed out after {} seconds'.format(i + 1, self.timeout))
+                self.logger.warning(
+                    'Checks on mailbox #{} timed out after {} seconds'.format(
+                        i + 1, self.timeout
+                    )
+                )
                 continue
 
         return results

@@ -276,16 +331,25 @@ class MailBackend(Backend):
         for i, (unread, flagged) in enumerate(results):
             unread_msgs, seen_msgs = self._get_unread_seen_msgs(i, unread)
             flagged_msgs, unflagged_msgs = self._get_flagged_unflagged_msgs(i, flagged)
-            self._process_msg_events(i, unread=list(unread_msgs.values()), seen=list(seen_msgs.values()),
-                                     flagged=list(flagged_msgs.values()), unflagged=list(unflagged_msgs.values()),
-                                     last_checked_date=mailbox_statuses[i].last_checked_date)
+            self._process_msg_events(
+                i,
+                unread=list(unread_msgs.values()),
+                seen=list(seen_msgs.values()),
+                flagged=list(flagged_msgs.values()),
+                unflagged=list(unflagged_msgs.values()),
+                last_checked_date=mailbox_statuses[i].last_checked_date,
+            )
 
             self._unread_msgs[i] = unread
             self._flagged_msgs[i] = flagged
-            records.append(MailboxStatus(mailbox_id=i,
-                                         unseen_message_ids=json.dumps([msg_id for msg_id in unread.keys()]),
-                                         flagged_message_ids=json.dumps([msg_id for msg_id in flagged.keys()]),
-                                         last_checked_date=datetime.now()))
+            records.append(
+                MailboxStatus(
+                    mailbox_id=i,
+                    unseen_message_ids=json.dumps(list(unread.keys())),
+                    flagged_message_ids=json.dumps(list(flagged.keys())),
+                    last_checked_date=datetime.now(),
+                )
+            )
 
         with self._db_lock:
             session = Session()
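One pre-existing line in this file is worth flagging while it is being touched: `unread_queue, flagged_queue = [Queue()] * 2` multiplies a list of references, so both names end up bound to the same `Queue` object rather than two independent queues. The commit keeps that behavior as-is; a quick standalone demonstration of the difference (not part of the diff):

from queue import Queue

a, b = [Queue()] * 2      # one Queue, two names
assert a is b

c, d = Queue(), Queue()   # two independent queues
assert c is not d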
@@ -5,7 +5,7 @@ from typing import Mapping, Type
 
 import pkgutil
 from sqlalchemy import Column, Index, Integer, String, DateTime, JSON, UniqueConstraint
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import declarative_base
 
 Base = declarative_base()
 entities_registry: Mapping[Type['Entity'], Mapping] = {}

@@ -24,14 +24,16 @@ class Entity(Base):
     type = Column(String, nullable=False, index=True)
     plugin = Column(String, nullable=False)
     data = Column(JSON, default=dict)
-    created_at = Column(DateTime(timezone=False), default=datetime.utcnow(), nullable=False)
-    updated_at = Column(DateTime(timezone=False), default=datetime.utcnow(), onupdate=datetime.now())
+    created_at = Column(
+        DateTime(timezone=False), default=datetime.utcnow(), nullable=False
+    )
+    updated_at = Column(
+        DateTime(timezone=False), default=datetime.utcnow(), onupdate=datetime.now()
+    )
 
     UniqueConstraint(external_id, plugin)
 
-    __table_args__ = (
-        Index(name, plugin),
-    )
+    __table_args__ = (Index(name, plugin),)
 
     __mapper_args__ = {
         'polymorphic_identity': __tablename__,

@@ -41,13 +43,14 @@ class Entity(Base):
 
 def _discover_entity_types():
     from platypush.context import get_plugin
+
     logger = get_plugin('logger')
     assert logger
 
     for loader, modname, _ in pkgutil.walk_packages(
         path=[str(pathlib.Path(__file__).parent.absolute())],
         prefix=__package__ + '.',
-        onerror=lambda _: None
+        onerror=lambda _: None,
     ):
         try:
             mod_loader = loader.find_module(modname)  # type: ignore

@@ -65,9 +68,9 @@ def _discover_entity_types():
 
 def init_entities_db():
     from platypush.context import get_plugin
 
     _discover_entity_types()
     db = get_plugin('db')
     assert db
     engine = db.get_engine()
     db.create_all(engine, Base)
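Note the contrast with the `__table_args__` cleanups elsewhere in this commit: here the value really is a tuple, and the trailing comma in `(Index(name, plugin),)` is load-bearing, since without it the parentheses would collapse and `__table_args__` would just be the `Index` object itself. A two-line illustration (hypothetical names):

args = (42,)     # one-element tuple
not_a_tuple = (42)  # plain int: parentheses alone don't make a tuple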
@ -3,9 +3,16 @@ import os
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from sqlalchemy import create_engine, Column, Integer, String, DateTime, PrimaryKeyConstraint, ForeignKey
|
from sqlalchemy import (
|
||||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
create_engine,
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
Column,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
DateTime,
|
||||||
|
PrimaryKeyConstraint,
|
||||||
|
ForeignKey,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
|
||||||
from sqlalchemy.sql.expression import func
|
from sqlalchemy.sql.expression import func
|
||||||
|
|
||||||
from platypush.config import Config
|
from platypush.config import Config
|
||||||
|
@ -38,7 +45,8 @@ class LocalMediaSearcher(MediaSearcher):
|
||||||
if not self._db_engine:
|
if not self._db_engine:
|
||||||
self._db_engine = create_engine(
|
self._db_engine = create_engine(
|
||||||
'sqlite:///{}'.format(self.db_file),
|
'sqlite:///{}'.format(self.db_file),
|
||||||
connect_args={'check_same_thread': False})
|
connect_args={'check_same_thread': False},
|
||||||
|
)
|
||||||
|
|
||||||
Base.metadata.create_all(self._db_engine)
|
Base.metadata.create_all(self._db_engine)
|
||||||
Session.configure(bind=self._db_engine)
|
Session.configure(bind=self._db_engine)
|
||||||
|
@ -57,27 +65,30 @@ class LocalMediaSearcher(MediaSearcher):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_last_modify_time(cls, path, recursive=False):
|
def _get_last_modify_time(cls, path, recursive=False):
|
||||||
return max([os.path.getmtime(p) for p, _, _ in os.walk(path)]) \
|
return (
|
||||||
if recursive else os.path.getmtime(path)
|
max([os.path.getmtime(p) for p, _, _ in os.walk(path)])
|
||||||
|
if recursive
|
||||||
|
else os.path.getmtime(path)
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _has_directory_changed_since_last_indexing(self, dir_record):
|
def _has_directory_changed_since_last_indexing(cls, dir_record):
|
||||||
if not dir_record.last_indexed_at:
|
if not dir_record.last_indexed_at:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return datetime.datetime.fromtimestamp(
|
return (
|
||||||
self._get_last_modify_time(dir_record.path)) > dir_record.last_indexed_at
|
datetime.datetime.fromtimestamp(cls._get_last_modify_time(dir_record.path))
|
||||||
|
> dir_record.last_indexed_at
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _matches_query(cls, filename, query):
|
def _matches_query(cls, filename, query):
|
||||||
filename = filename.lower()
|
filename = filename.lower()
|
||||||
query_tokens = [_.lower() for _ in re.split(
|
query_tokens = [
|
||||||
cls._filename_separators, query.strip())]
|
_.lower() for _ in re.split(cls._filename_separators, query.strip())
|
||||||
|
]
|
||||||
|
|
||||||
for token in query_tokens:
|
return all(token in filename for token in query_tokens)
|
||||||
if token not in filename:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _sync_token_records(cls, session, *tokens):
|
def _sync_token_records(cls, session, *tokens):
|
||||||
|
@ -85,9 +96,12 @@ class LocalMediaSearcher(MediaSearcher):
|
||||||
if not tokens:
|
if not tokens:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
records = {record.token: record for record in
|
records = {
|
||||||
session.query(MediaToken).filter(
|
record.token: record
|
||||||
MediaToken.token.in_(tokens)).all()}
|
for record in session.query(MediaToken)
|
||||||
|
.filter(MediaToken.token.in_(tokens))
|
||||||
|
.all()
|
||||||
|
}
|
||||||
|
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
if token in records:
|
if token in records:
|
||||||
|
@ -97,13 +111,11 @@ class LocalMediaSearcher(MediaSearcher):
|
||||||
records[token] = record
|
records[token] = record
|
||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
return session.query(MediaToken).filter(
|
return session.query(MediaToken).filter(MediaToken.token.in_(tokens)).all()
|
||||||
MediaToken.token.in_(tokens)).all()
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_file_records(cls, dir_record, session):
|
def _get_file_records(cls, dir_record, session):
|
||||||
return session.query(MediaFile).filter_by(
|
return session.query(MediaFile).filter_by(directory_id=dir_record.id).all()
|
||||||
directory_id=dir_record.id).all()
|
|
||||||
|
|
||||||
def scan(self, media_dir, session=None, dir_record=None):
|
def scan(self, media_dir, session=None, dir_record=None):
|
||||||
"""
|
"""
|
||||||
|
@ -121,17 +133,19 @@ class LocalMediaSearcher(MediaSearcher):
|
||||||
dir_record = self._get_or_create_dir_entry(session, media_dir)
|
dir_record = self._get_or_create_dir_entry(session, media_dir)
|
||||||
|
|
||||||
if not os.path.isdir(media_dir):
|
if not os.path.isdir(media_dir):
|
||||||
self.logger.info('Directory {} is no longer accessible, removing it'.
|
self.logger.info(
|
||||||
format(media_dir))
|
'Directory {} is no longer accessible, removing it'.format(media_dir)
|
||||||
session.query(MediaDirectory) \
|
)
|
||||||
.filter(MediaDirectory.path == media_dir) \
|
session.query(MediaDirectory).filter(
|
||||||
.delete(synchronize_session='fetch')
|
MediaDirectory.path == media_dir
|
||||||
|
).delete(synchronize_session='fetch')
|
||||||
return
|
return
|
||||||
|
|
||||||
stored_file_records = {
|
stored_file_records = {
|
||||||
f.path: f for f in self._get_file_records(dir_record, session)}
|
f.path: f for f in self._get_file_records(dir_record, session)
|
||||||
|
}
|
||||||
|
|
||||||
for path, dirs, files in os.walk(media_dir):
|
for path, _, files in os.walk(media_dir):
|
||||||
for filename in files:
|
for filename in files:
|
||||||
filepath = os.path.join(path, filename)
|
filepath = os.path.join(path, filename)
|
||||||
|
|
||||||
|
@ -142,26 +156,32 @@ class LocalMediaSearcher(MediaSearcher):
|
||||||
del stored_file_records[filepath]
|
del stored_file_records[filepath]
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not MediaPlugin.is_video_file(filename) and \
|
if not MediaPlugin.is_video_file(
|
||||||
not MediaPlugin.is_audio_file(filename):
|
filename
|
||||||
|
) and not MediaPlugin.is_audio_file(filename):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.logger.debug('Syncing item {}'.format(filepath))
|
self.logger.debug('Syncing item {}'.format(filepath))
|
||||||
tokens = [_.lower() for _ in re.split(self._filename_separators,
|
tokens = [
|
||||||
filename.strip())]
|
_.lower()
|
||||||
|
for _ in re.split(self._filename_separators, filename.strip())
|
||||||
|
]
|
||||||
|
|
||||||
token_records = self._sync_token_records(session, *tokens)
|
token_records = self._sync_token_records(session, *tokens)
|
||||||
file_record = MediaFile.build(directory_id=dir_record.id,
|
file_record = MediaFile.build(directory_id=dir_record.id, path=filepath)
|
||||||
path=filepath)
|
|
||||||
|
|
||||||
session.add(file_record)
|
session.add(file_record)
|
||||||
session.commit()
|
session.commit()
|
||||||
file_record = session.query(MediaFile).filter_by(
|
file_record = (
|
||||||
directory_id=dir_record.id, path=filepath).one()
|
session.query(MediaFile)
|
||||||
|
.filter_by(directory_id=dir_record.id, path=filepath)
|
||||||
|
.one()
|
||||||
|
)
|
||||||
|
|
||||||
for token_record in token_records:
|
for token_record in token_records:
|
||||||
file_token = MediaFileToken.build(file_id=file_record.id,
|
file_token = MediaFileToken.build(
|
||||||
token_id=token_record.id)
|
file_id=file_record.id, token_id=token_record.id
|
||||||
|
)
|
||||||
session.add(file_token)
|
session.add(file_token)
|
||||||
|
|
||||||
# stored_file_records should now only contain the records of the files
|
# stored_file_records should now only contain the records of the files
|
||||||
|
@ -169,15 +189,20 @@ class LocalMediaSearcher(MediaSearcher):
|
||||||
if stored_file_records:
|
if stored_file_records:
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
'Removing references to {} deleted media items from {}'.format(
|
'Removing references to {} deleted media items from {}'.format(
|
||||||
len(stored_file_records), media_dir))
|
len(stored_file_records), media_dir
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
session.query(MediaFile).filter(MediaFile.id.in_(
|
session.query(MediaFile).filter(
|
||||||
[record.id for record in stored_file_records.values()]
|
MediaFile.id.in_([record.id for record in stored_file_records.values()])
|
||||||
)).delete(synchronize_session='fetch')
|
).delete(synchronize_session='fetch')
|
||||||
|
|
||||||
dir_record.last_indexed_at = datetime.datetime.now()
|
dir_record.last_indexed_at = datetime.datetime.now()
|
||||||
self.logger.info('Scanned {} in {} seconds'.format(
|
self.logger.info(
|
||||||
media_dir, int(time.time() - index_start_time)))
|
'Scanned {} in {} seconds'.format(
|
||||||
|
media_dir, int(time.time() - index_start_time)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
@@ -197,25 +222,30 @@ class LocalMediaSearcher(MediaSearcher):
         dir_record = self._get_or_create_dir_entry(session, media_dir)
 
         if self._has_directory_changed_since_last_indexing(dir_record):
-            self.logger.info('{} has changed since last indexing, '.format(
-                media_dir) + 're-indexing')
+            self.logger.info(
+                '{} has changed since last indexing, '.format(media_dir)
+                + 're-indexing'
+            )
 
             self.scan(media_dir, session=session, dir_record=dir_record)
 
-        query_tokens = [_.lower() for _ in re.split(
-            self._filename_separators, query.strip())]
+        query_tokens = [
+            _.lower() for _ in re.split(self._filename_separators, query.strip())
+        ]
 
-        for file_record in session.query(MediaFile.path). \
-                join(MediaFileToken). \
-                join(MediaToken). \
-                filter(MediaToken.token.in_(query_tokens)). \
-                group_by(MediaFile.path). \
-                having(func.count(MediaFileToken.token_id) >= len(query_tokens)):
+        for file_record in (
+            session.query(MediaFile.path)
+            .join(MediaFileToken)
+            .join(MediaToken)
+            .filter(MediaToken.token.in_(query_tokens))
+            .group_by(MediaFile.path)
+            .having(func.count(MediaFileToken.token_id) >= len(query_tokens))
+        ):
             if os.path.isfile(file_record.path):
                 results[file_record.path] = {
                     'url': 'file://' + file_record.path,
                     'title': os.path.basename(file_record.path),
-                    'size': os.path.getsize(file_record.path)
+                    'size': os.path.getsize(file_record.path),
                 }
 
         return results.values()
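Aside: the query above implements AND-matching over tokens. Files are joined to their tokens, rows are filtered to tokens present in the query, and GROUP BY/HAVING keeps only files that matched at least as many token rows as there are query tokens. A hedged, self-contained sketch of the same pattern with simplified stand-in models (not the plugin's actual tables):

    from sqlalchemy import (
        Column, ForeignKey, Integer, String, create_engine, func
    )
    from sqlalchemy.orm import declarative_base, sessionmaker

    Base = declarative_base()

    class File(Base):
        __tablename__ = 'file'
        id = Column(Integer, primary_key=True)
        path = Column(String, unique=True)

    class Token(Base):
        __tablename__ = 'token'
        id = Column(Integer, primary_key=True)
        token = Column(String, unique=True)

    class FileToken(Base):
        __tablename__ = 'file_token'
        id = Column(Integer, primary_key=True)
        file_id = Column(Integer, ForeignKey('file.id'))
        token_id = Column(Integer, ForeignKey('token.id'))

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()

    # Index one file under its tokens.
    f = File(path='/media/The.Big.Lebowski.1998.mkv')
    session.add(f)
    session.flush()  # assigns f.id
    for word in ('the', 'big', 'lebowski', '1998', 'mkv'):
        t = Token(token=word)
        session.add(t)
        session.flush()  # assigns t.id
        session.add(FileToken(file_id=f.id, token_id=t.id))
    session.commit()

    query_tokens = ['big', 'lebowski']
    matches = (
        session.query(File.path)
        .join(FileToken, FileToken.file_id == File.id)
        .join(Token, Token.id == FileToken.token_id)
        .filter(Token.token.in_(query_tokens))
        .group_by(File.path)
        # AND semantics: the file must have matched every query token.
        .having(func.count(FileToken.token_id) >= len(query_tokens))
    )
    print([row.path for row in matches])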
@@ -223,11 +253,12 @@ class LocalMediaSearcher(MediaSearcher):
 
 # --- Table definitions
 
+
 class MediaDirectory(Base):
     """Models the MediaDirectory table"""
 
     __tablename__ = 'MediaDirectory'
-    __table_args__ = ({'sqlite_autoincrement': True})
+    __table_args__ = {'sqlite_autoincrement': True}
 
     id = Column(Integer, primary_key=True)
     path = Column(String)
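Note on the recurring `__table_args__` fix in this commit: `({'sqlite_autoincrement': True})` was never a one-element tuple, because parentheses without a trailing comma do not create a tuple in Python; the old spelling was just a dict wrapped in redundant parentheses. SQLAlchemy accepts a bare dict here, so behavior is unchanged and the new form simply says what it is:

    # Parentheses alone do not make a tuple: this is still a dict.
    args = ({'sqlite_autoincrement': True})
    print(type(args))  # <class 'dict'>

    # Only a trailing comma would turn it into a one-element tuple.
    args = ({'sqlite_autoincrement': True},)
    print(type(args))  # <class 'tuple'>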
@@ -246,11 +277,12 @@ class MediaFile(Base):
     """Models the MediaFile table"""
 
     __tablename__ = 'MediaFile'
-    __table_args__ = ({'sqlite_autoincrement': True})
+    __table_args__ = {'sqlite_autoincrement': True}
 
     id = Column(Integer, primary_key=True)
-    directory_id = Column(Integer, ForeignKey(
-        'MediaDirectory.id', ondelete='CASCADE'), nullable=False)
+    directory_id = Column(
+        Integer, ForeignKey('MediaDirectory.id', ondelete='CASCADE'), nullable=False
+    )
     path = Column(String, nullable=False, unique=True)
     indexed_at = Column(DateTime)
 
@@ -268,7 +300,7 @@ class MediaToken(Base):
     """Models the MediaToken table"""
 
     __tablename__ = 'MediaToken'
-    __table_args__ = ({'sqlite_autoincrement': True})
+    __table_args__ = {'sqlite_autoincrement': True}
 
     id = Column(Integer, primary_key=True)
     token = Column(String, nullable=False, unique=True)
@@ -286,10 +318,12 @@ class MediaFileToken(Base):
 
     __tablename__ = 'MediaFileToken'
 
-    file_id = Column(Integer, ForeignKey('MediaFile.id', ondelete='CASCADE'),
-                     nullable=False)
-    token_id = Column(Integer, ForeignKey('MediaToken.id', ondelete='CASCADE'),
-                      nullable=False)
+    file_id = Column(
+        Integer, ForeignKey('MediaFile.id', ondelete='CASCADE'), nullable=False
+    )
+    token_id = Column(
+        Integer, ForeignKey('MediaToken.id', ondelete='CASCADE'), nullable=False
+    )
 
     __table_args__ = (PrimaryKeyConstraint(file_id, token_id), {})
 
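`MediaFileToken` is the one model that keeps the tuple form of `__table_args__`, and here the tuple is real: it carries a positional `PrimaryKeyConstraint` spanning both foreign keys (so each file/token pair can appear only once) plus a trailing options dict. A minimal sketch of the same pattern, with assumed `file`/`token` target tables:

    from sqlalchemy import Column, ForeignKey, Integer, PrimaryKeyConstraint
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class FileToken(Base):
        """Association table whose primary key is the (file, token) pair."""
        __tablename__ = 'file_token'

        file_id = Column(Integer, ForeignKey('file.id'), nullable=False)
        token_id = Column(Integer, ForeignKey('token.id'), nullable=False)

        # Real tuple: positional constraints first, options dict last.
        __table_args__ = (PrimaryKeyConstraint(file_id, token_id), {})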
@@ -301,4 +335,5 @@ class MediaFileToken(Base):
         record.token_id = token_id
         return record
 
+
 # vim:sw=4:ts=4:et:
@@ -13,11 +13,13 @@ except ImportError:
 from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode
 
 from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
-from sqlalchemy.orm import make_transient
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import make_transient, declarative_base
 
 from platypush.context import get_plugin
-from platypush.exceptions.user import InvalidJWTTokenException, InvalidCredentialsException
+from platypush.exceptions.user import (
+    InvalidJWTTokenException,
+    InvalidCredentialsException,
+)
 from platypush.utils import get_or_generate_jwt_rsa_key_pair
 
 Base = declarative_base()
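This hunk carries the actual point of the commit: since SQLAlchemy 1.4, `declarative_base` is exported directly from `sqlalchemy.orm`, and the old `sqlalchemy.ext.declarative` import path is deprecated. Only the import location changes; the returned base class is the same:

    # Deprecated since SQLAlchemy 1.4:
    #     from sqlalchemy.ext.declarative import declarative_base
    # Current import path:
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()  # behaves exactly as before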
@@ -68,8 +70,12 @@ class UserManager:
             if user:
                 raise NameError('The user {} already exists'.format(username))
 
-            record = User(username=username, password=self._encrypt_password(password),
-                          created_at=datetime.datetime.utcnow(), **kwargs)
+            record = User(
+                username=username,
+                password=self._encrypt_password(password),
+                created_at=datetime.datetime.utcnow(),
+                **kwargs
+            )
 
             session.add(record)
             session.commit()
@@ -93,10 +99,16 @@ class UserManager:
 
     def authenticate_user_session(self, session_token):
         with self.db.get_session() as session:
-            user_session = session.query(UserSession).filter_by(session_token=session_token).first()
+            user_session = (
+                session.query(UserSession)
+                .filter_by(session_token=session_token)
+                .first()
+            )
 
             if not user_session or (
-                    user_session.expires_at and user_session.expires_at < datetime.datetime.utcnow()):
+                user_session.expires_at
+                and user_session.expires_at < datetime.datetime.utcnow()
+            ):
                 return None, None
 
             user = session.query(User).filter_by(user_id=user_session.user_id).first()
@@ -108,7 +120,9 @@ class UserManager:
             if not user:
                 raise NameError('No such user: {}'.format(username))
 
-            user_sessions = session.query(UserSession).filter_by(user_id=user.user_id).all()
+            user_sessions = (
+                session.query(UserSession).filter_by(user_id=user.user_id).all()
+            )
             for user_session in user_sessions:
                 session.delete(user_session)
 
@@ -118,7 +132,11 @@ class UserManager:
 
     def delete_user_session(self, session_token):
         with self.db.get_session() as session:
-            user_session = session.query(UserSession).filter_by(session_token=session_token).first()
+            user_session = (
+                session.query(UserSession)
+                .filter_by(session_token=session_token)
+                .first()
+            )
 
             if not user_session:
                 return False
@@ -134,14 +152,18 @@ class UserManager:
                 return None
 
             if expires_at:
-                if isinstance(expires_at, int) or isinstance(expires_at, float):
+                if isinstance(expires_at, (int, float)):
                     expires_at = datetime.datetime.fromtimestamp(expires_at)
                 elif isinstance(expires_at, str):
                     expires_at = datetime.datetime.fromisoformat(expires_at)
 
-            user_session = UserSession(user_id=user.user_id, session_token=self.generate_session_token(),
-                                       csrf_token=self.generate_session_token(), created_at=datetime.datetime.utcnow(),
-                                       expires_at=expires_at)
+            user_session = UserSession(
+                user_id=user.user_id,
+                session_token=self.generate_session_token(),
+                csrf_token=self.generate_session_token(),
+                created_at=datetime.datetime.utcnow(),
+                expires_at=expires_at,
+            )
 
             session.add(user_session)
             session.commit()
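The `isinstance(expires_at, (int, float))` change is worth a note: `isinstance` accepts a tuple of types, so the two chained calls collapse into one with identical behavior. For reference, a standalone restatement of the expiry normalization this method performs (the helper name is made up for illustration):

    import datetime
    from typing import Optional, Union

    def normalize_expires_at(
        expires_at: Optional[Union[int, float, str, datetime.datetime]]
    ) -> Optional[datetime.datetime]:
        """Coerce the supported expiry spellings into a datetime."""
        if isinstance(expires_at, (int, float)):  # Unix timestamp
            return datetime.datetime.fromtimestamp(expires_at)
        if isinstance(expires_at, str):  # ISO-8601 string
            return datetime.datetime.fromisoformat(expires_at)
        return expires_at  # already a datetime, or None

    print(normalize_expires_at(1700000000.0))
    print(normalize_expires_at('2030-01-01T00:00:00'))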
@@ -179,9 +201,19 @@ class UserManager:
         :param session_token: Session token.
         """
         with self.db.get_session() as session:
-            return session.query(User).join(UserSession).filter_by(session_token=session_token).first()
+            return (
+                session.query(User)
+                .join(UserSession)
+                .filter_by(session_token=session_token)
+                .first()
+            )
 
-    def generate_jwt_token(self, username: str, password: str, expires_at: Optional[datetime.datetime] = None) -> str:
+    def generate_jwt_token(
+        self,
+        username: str,
+        password: str,
+        expires_at: Optional[datetime.datetime] = None,
+    ) -> str:
         """
         Create a user JWT token for API usage.
 
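The reflowed `generate_jwt_token` signature above is purely cosmetic. For orientation, the method's job is to turn credentials into a signed JWT via the `jwt_encode`/`jwt_decode` imports at the top of this file; a rough sketch of that flow with a hypothetical HS256 shared secret and payload (the real implementation authenticates the user first and signs with the RSA key pair from `get_or_generate_jwt_rsa_key_pair`):

    import datetime
    from jwt import encode as jwt_encode, decode as jwt_decode

    secret = 'not-a-real-secret'  # hypothetical; the real code uses an RSA key pair

    payload = {  # hypothetical claims, for illustration only
        'username': 'admin',
        'expires_at': datetime.datetime(2030, 1, 1).timestamp(),
    }

    token = jwt_encode(payload, secret, algorithm='HS256')
    claims = jwt_decode(token, secret, algorithms=['HS256'])
    print(claims['username'])  # -> admin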
@@ -256,7 +288,7 @@ class User(Base):
     """Models the User table"""
 
     __tablename__ = 'user'
-    __table_args__ = ({'sqlite_autoincrement': True})
+    __table_args__ = {'sqlite_autoincrement': True}
 
     user_id = Column(Integer, primary_key=True)
     username = Column(String, unique=True, nullable=False)
@@ -268,7 +300,7 @@ class UserSession(Base):
     """Models the UserSession table"""
 
     __tablename__ = 'user_session'
-    __table_args__ = ({'sqlite_autoincrement': True})
+    __table_args__ = {'sqlite_autoincrement': True}
 
     session_id = Column(Integer, primary_key=True)
     session_token = Column(String, unique=True, nullable=False)