Replaced deprecated sqlalchemy.ext.declarative with sqlalchemy.orm

This commit is contained in:
Fabio Manganiello 2022-04-05 22:47:44 +02:00
parent 4b7eeaa4ed
commit 8a70f1d38e
Signed by: blacklight
GPG Key ID: D90FBA7F76362774
7 changed files with 540 additions and 252 deletions

View File

@ -3,8 +3,7 @@ import os
from typing import Optional, Union, List, Dict, Any
from sqlalchemy import create_engine, Column, Integer, String, DateTime
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
from platypush.backend import Backend
from platypush.config import Config
@ -17,10 +16,10 @@ Session = scoped_session(sessionmaker())
class Covid19Update(Base):
""" Models the Covid19Data table """
"""Models the Covid19Data table"""
__tablename__ = 'covid19data'
__table_args__ = ({'sqlite_autoincrement': True})
__table_args__ = {'sqlite_autoincrement': True}
country = Column(String, primary_key=True)
confirmed = Column(Integer, nullable=False, default=0)
@ -40,7 +39,12 @@ class Covid19Backend(Backend):
"""
# noinspection PyProtectedMember
def __init__(self, country: Optional[Union[str, List[str]]], poll_seconds: Optional[float] = 3600.0, **kwargs):
def __init__(
self,
country: Optional[Union[str, List[str]]],
poll_seconds: Optional[float] = 3600.0,
**kwargs
):
"""
:param country: Default country (or list of countries) to retrieve the stats for. It can either be the full
country name or the country code. Special values:
@ -56,7 +60,9 @@ class Covid19Backend(Backend):
super().__init__(poll_seconds=poll_seconds, **kwargs)
self._plugin: Covid19Plugin = get_plugin('covid19')
self.country: List[str] = self._plugin._get_countries(country)
self.workdir = os.path.join(os.path.expanduser(Config.get('workdir')), 'covid19')
self.workdir = os.path.join(
os.path.expanduser(Config.get('workdir')), 'covid19'
)
self.dbfile = os.path.join(self.workdir, 'data.db')
os.makedirs(self.workdir, exist_ok=True)
@ -67,22 +73,30 @@ class Covid19Backend(Backend):
self.logger.info('Stopped Covid19 backend')
def _process_update(self, summary: Dict[str, Any], session: Session):
update_time = datetime.datetime.fromisoformat(summary['Date'].replace('Z', '+00:00'))
update_time = datetime.datetime.fromisoformat(
summary['Date'].replace('Z', '+00:00')
)
self.bus.post(Covid19UpdateEvent(
country=summary['Country'],
country_code=summary['CountryCode'],
confirmed=summary['TotalConfirmed'],
deaths=summary['TotalDeaths'],
recovered=summary['TotalRecovered'],
update_time=update_time,
))
self.bus.post(
Covid19UpdateEvent(
country=summary['Country'],
country_code=summary['CountryCode'],
confirmed=summary['TotalConfirmed'],
deaths=summary['TotalDeaths'],
recovered=summary['TotalRecovered'],
update_time=update_time,
)
)
session.merge(Covid19Update(country=summary['CountryCode'],
confirmed=summary['TotalConfirmed'],
deaths=summary['TotalDeaths'],
recovered=summary['TotalRecovered'],
last_updated_at=update_time))
session.merge(
Covid19Update(
country=summary['CountryCode'],
confirmed=summary['TotalConfirmed'],
deaths=summary['TotalDeaths'],
recovered=summary['TotalRecovered'],
last_updated_at=update_time,
)
)
def loop(self):
# noinspection PyUnresolvedReferences
@ -90,23 +104,30 @@ class Covid19Backend(Backend):
if not summaries:
return
engine = create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False})
engine = create_engine(
'sqlite:///{}'.format(self.dbfile),
connect_args={'check_same_thread': False},
)
Base.metadata.create_all(engine)
Session.configure(bind=engine)
session = Session()
last_records = {
record.country: record
for record in session.query(Covid19Update).filter(Covid19Update.country.in_(self.country)).all()
for record in session.query(Covid19Update)
.filter(Covid19Update.country.in_(self.country))
.all()
}
for summary in summaries:
country = summary['CountryCode']
last_record = last_records.get(country)
if not last_record or \
summary['TotalConfirmed'] != last_record.confirmed or \
summary['TotalDeaths'] != last_record.deaths or \
summary['TotalRecovered'] != last_record.recovered:
if (
not last_record
or summary['TotalConfirmed'] != last_record.confirmed
or summary['TotalDeaths'] != last_record.deaths
or summary['TotalRecovered'] != last_record.recovered
):
self._process_update(summary=summary, session=session)
session.commit()

View File

@ -6,15 +6,28 @@ from typing import Optional, List
import requests
from sqlalchemy import create_engine, Column, String, DateTime
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
from platypush.backend import Backend
from platypush.config import Config
from platypush.message.event.github import GithubPushEvent, GithubCommitCommentEvent, GithubCreateEvent, \
GithubDeleteEvent, GithubEvent, GithubForkEvent, GithubWikiEvent, GithubIssueCommentEvent, GithubIssueEvent, \
GithubMemberEvent, GithubPublicEvent, GithubPullRequestEvent, GithubPullRequestReviewCommentEvent, \
GithubReleaseEvent, GithubSponsorshipEvent, GithubWatchEvent
from platypush.message.event.github import (
GithubPushEvent,
GithubCommitCommentEvent,
GithubCreateEvent,
GithubDeleteEvent,
GithubEvent,
GithubForkEvent,
GithubWikiEvent,
GithubIssueCommentEvent,
GithubIssueEvent,
GithubMemberEvent,
GithubPublicEvent,
GithubPullRequestEvent,
GithubPullRequestReviewCommentEvent,
GithubReleaseEvent,
GithubSponsorshipEvent,
GithubWatchEvent,
)
Base = declarative_base()
Session = scoped_session(sessionmaker())
@ -71,8 +84,17 @@ class GithubBackend(Backend):
_base_url = 'https://api.github.com'
def __init__(self, user: str, user_token: str, repos: Optional[List[str]] = None, org: Optional[str] = None,
poll_seconds: int = 60, max_events_per_scan: Optional[int] = 10, *args, **kwargs):
def __init__(
self,
user: str,
user_token: str,
repos: Optional[List[str]] = None,
org: Optional[str] = None,
poll_seconds: int = 60,
max_events_per_scan: Optional[int] = 10,
*args,
**kwargs
):
"""
If neither ``repos`` nor ``org`` is specified then the backend will monitor all new events on user level.
@ -102,17 +124,23 @@ class GithubBackend(Backend):
def _request(self, uri: str, method: str = 'get') -> dict:
method = getattr(requests, method.lower())
return method(self._base_url + uri, auth=(self.user, self.user_token),
headers={'Accept': 'application/vnd.github.v3+json'}).json()
return method(
self._base_url + uri,
auth=(self.user, self.user_token),
headers={'Accept': 'application/vnd.github.v3+json'},
).json()
def _init_db(self):
engine = create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False})
engine = create_engine(
'sqlite:///{}'.format(self.dbfile),
connect_args={'check_same_thread': False},
)
Base.metadata.create_all(engine)
Session.configure(bind=engine)
@staticmethod
def _to_datetime(time_string: str) -> datetime.datetime:
""" Convert ISO 8061 string format with leading 'Z' into something understandable by Python """
"""Convert ISO 8061 string format with leading 'Z' into something understandable by Python"""
return datetime.datetime.fromisoformat(time_string[:-1] + '+00:00')
@staticmethod
@ -128,7 +156,11 @@ class GithubBackend(Backend):
def _get_last_event_time(self, uri: str):
with self.db_lock:
record = self._get_or_create_resource(uri=uri, session=Session())
return record.last_updated_at.replace(tzinfo=datetime.timezone.utc) if record.last_updated_at else None
return (
record.last_updated_at.replace(tzinfo=datetime.timezone.utc)
if record.last_updated_at
else None
)
def _update_last_event_time(self, uri: str, last_updated_at: datetime.datetime):
with self.db_lock:
@ -158,9 +190,18 @@ class GithubBackend(Backend):
'WatchEvent': GithubWatchEvent,
}
event_type = event_mapping[event['type']] if event['type'] in event_mapping else GithubEvent
return event_type(event_type=event['type'], actor=event['actor'], repo=event.get('repo', {}),
payload=event['payload'], created_at=cls._to_datetime(event['created_at']))
event_type = (
event_mapping[event['type']]
if event['type'] in event_mapping
else GithubEvent
)
return event_type(
event_type=event['type'],
actor=event['actor'],
repo=event.get('repo', {}),
payload=event['payload'],
created_at=cls._to_datetime(event['created_at']),
)
def _events_monitor(self, uri: str, method: str = 'get'):
def thread():
@ -175,7 +216,10 @@ class GithubBackend(Backend):
fired_events = []
for event in events:
if self.max_events_per_scan and len(fired_events) >= self.max_events_per_scan:
if (
self.max_events_per_scan
and len(fired_events) >= self.max_events_per_scan
):
break
event_time = self._to_datetime(event['created_at'])
@ -189,14 +233,19 @@ class GithubBackend(Backend):
for event in fired_events:
self.bus.post(event)
self._update_last_event_time(uri=uri, last_updated_at=new_last_event_time)
self._update_last_event_time(
uri=uri, last_updated_at=new_last_event_time
)
except Exception as e:
self.logger.warning('Encountered exception while fetching events from {}: {}'.format(
uri, str(e)))
self.logger.warning(
'Encountered exception while fetching events from {}: {}'.format(
uri, str(e)
)
)
self.logger.exception(e)
finally:
if self.wait_stop(timeout=self.poll_seconds):
break
if self.wait_stop(timeout=self.poll_seconds):
break
return thread
@ -206,12 +255,30 @@ class GithubBackend(Backend):
if self.repos:
for repo in self.repos:
monitors.append(threading.Thread(target=self._events_monitor('/networks/{repo}/events'.format(repo=repo))))
monitors.append(
threading.Thread(
target=self._events_monitor(
'/networks/{repo}/events'.format(repo=repo)
)
)
)
if self.org:
monitors.append(threading.Thread(target=self._events_monitor('/orgs/{org}/events'.format(org=self.org))))
monitors.append(
threading.Thread(
target=self._events_monitor(
'/orgs/{org}/events'.format(org=self.org)
)
)
)
if not (self.repos or self.org):
monitors.append(threading.Thread(target=self._events_monitor('/users/{user}/events'.format(user=self.user))))
monitors.append(
threading.Thread(
target=self._events_monitor(
'/users/{user}/events'.format(user=self.user)
)
)
)
for monitor in monitors:
monitor.start()
@ -222,4 +289,5 @@ class GithubBackend(Backend):
self.logger.info('Github backend terminated')
# vim:sw=4:ts=4:et:

View File

@ -2,11 +2,17 @@ import datetime
import enum
import os
from sqlalchemy import create_engine, Column, Integer, String, DateTime, \
Enum, ForeignKey
from sqlalchemy import (
create_engine,
Column,
Integer,
String,
DateTime,
Enum,
ForeignKey,
)
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
from sqlalchemy.sql.expression import func
from platypush.backend.http.request import HttpRequest
@ -44,18 +50,31 @@ class RssUpdates(HttpRequest):
"""
user_agent = 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) ' + \
'Chrome/62.0.3202.94 Safari/537.36'
user_agent = (
'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) '
+ 'Chrome/62.0.3202.94 Safari/537.36'
)
def __init__(self, url, title=None, headers=None, params=None, max_entries=None,
extract_content=False, digest_format=None, user_agent: str = user_agent,
body_style: str = 'font-size: 22px; ' +
'font-family: "Merriweather", Georgia, "Times New Roman", Times, serif;',
title_style: str = 'margin-top: 30px',
subtitle_style: str = 'margin-top: 10px; page-break-after: always',
article_title_style: str = 'page-break-before: always',
article_link_style: str = 'color: #555; text-decoration: none; border-bottom: 1px dotted',
article_content_style: str = '', *argv, **kwargs):
def __init__(
self,
url,
title=None,
headers=None,
params=None,
max_entries=None,
extract_content=False,
digest_format=None,
user_agent: str = user_agent,
body_style: str = 'font-size: 22px; '
+ 'font-family: "Merriweather", Georgia, "Times New Roman", Times, serif;',
title_style: str = 'margin-top: 30px',
subtitle_style: str = 'margin-top: 10px; page-break-after: always',
article_title_style: str = 'page-break-before: always',
article_link_style: str = 'color: #555; text-decoration: none; border-bottom: 1px dotted',
article_content_style: str = '',
*argv,
**kwargs,
):
"""
:param url: URL to the RSS feed to be monitored.
:param title: Optional title for the feed.
@ -91,7 +110,9 @@ class RssUpdates(HttpRequest):
# If true, then the http.webpage plugin will be used to parse the content
self.extract_content = extract_content
self.digest_format = digest_format.lower() if digest_format else None # Supported formats: html, pdf
self.digest_format = (
digest_format.lower() if digest_format else None
) # Supported formats: html, pdf
os.makedirs(os.path.expanduser(os.path.dirname(self.dbfile)), exist_ok=True)
@ -119,7 +140,11 @@ class RssUpdates(HttpRequest):
@staticmethod
def _get_latest_update(session, source_id):
return session.query(func.max(FeedEntry.published)).filter_by(source_id=source_id).scalar()
return (
session.query(func.max(FeedEntry.published))
.filter_by(source_id=source_id)
.scalar()
)
def _parse_entry_content(self, link):
self.logger.info('Extracting content from {}'.format(link))
@ -130,14 +155,20 @@ class RssUpdates(HttpRequest):
errors = response.errors
if not output:
self.logger.warning('Mercury parser error: {}'.format(errors or '[unknown error]'))
self.logger.warning(
'Mercury parser error: {}'.format(errors or '[unknown error]')
)
return
return output.get('content')
def get_new_items(self, response):
import feedparser
engine = create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False})
engine = create_engine(
'sqlite:///{}'.format(self.dbfile),
connect_args={'check_same_thread': False},
)
Base.metadata.create_all(engine)
Session.configure(bind=engine)
@ -157,12 +188,16 @@ class RssUpdates(HttpRequest):
content = u'''
<h1 style="{title_style}">{title}</h1>
<h2 style="{subtitle_style}">Feeds digest generated on {creation_date}</h2>'''.\
format(title_style=self.title_style, title=self.title, subtitle_style=self.subtitle_style,
creation_date=datetime.datetime.now().strftime('%d %B %Y, %H:%M'))
<h2 style="{subtitle_style}">Feeds digest generated on {creation_date}</h2>'''.format(
title_style=self.title_style,
title=self.title,
subtitle_style=self.subtitle_style,
creation_date=datetime.datetime.now().strftime('%d %B %Y, %H:%M'),
)
self.logger.info('Parsed {:d} items from RSS feed <{}>'
.format(len(feed.entries), self.url))
self.logger.info(
'Parsed {:d} items from RSS feed <{}>'.format(len(feed.entries), self.url)
)
for entry in feed.entries:
if not entry.published_parsed:
@ -171,9 +206,10 @@ class RssUpdates(HttpRequest):
try:
entry_timestamp = datetime.datetime(*entry.published_parsed[:6])
if latest_update is None \
or entry_timestamp > latest_update:
self.logger.info('Processed new item from RSS feed <{}>'.format(self.url))
if latest_update is None or entry_timestamp > latest_update:
self.logger.info(
'Processed new item from RSS feed <{}>'.format(self.url)
)
entry.summary = entry.summary if hasattr(entry, 'summary') else None
if self.extract_content:
@ -188,9 +224,13 @@ class RssUpdates(HttpRequest):
<a href="{link}" target="_blank" style="{article_link_style}">{title}</a>
</h1>
<div class="_parsed-content" style="{article_content_style}">{content}</div>'''.format(
article_title_style=self.article_title_style, article_link_style=self.article_link_style,
article_content_style=self.article_content_style, link=entry.link, title=entry.title,
content=entry.content)
article_title_style=self.article_title_style,
article_link_style=self.article_link_style,
article_content_style=self.article_content_style,
link=entry.link,
title=entry.title,
content=entry.content,
)
e = {
'entry_id': entry.id,
@ -207,21 +247,32 @@ class RssUpdates(HttpRequest):
if self.max_entries and len(entries) > self.max_entries:
break
except Exception as e:
self.logger.warning('Exception encountered while parsing RSS ' +
'RSS feed {}: {}'.format(entry.link, str(e)))
self.logger.warning(
'Exception encountered while parsing RSS '
+ f'RSS feed {entry.link}: {e}'
)
self.logger.exception(e)
source_record.last_updated_at = parse_start_time
digest_filename = None
if entries:
self.logger.info('Parsed {} new entries from the RSS feed {}'.format(
len(entries), self.title))
self.logger.info(
'Parsed {} new entries from the RSS feed {}'.format(
len(entries), self.title
)
)
if self.digest_format:
digest_filename = os.path.join(self.workdir, 'cache', '{}_{}.{}'.format(
datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'),
self.title, self.digest_format))
digest_filename = os.path.join(
self.workdir,
'cache',
'{}_{}.{}'.format(
datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'),
self.title,
self.digest_format,
),
)
os.makedirs(os.path.dirname(digest_filename), exist_ok=True)
@ -233,12 +284,15 @@ class RssUpdates(HttpRequest):
</head>
<body style="{body_style}">{content}</body>
</html>
'''.format(title=self.title, body_style=self.body_style, content=content)
'''.format(
title=self.title, body_style=self.body_style, content=content
)
with open(digest_filename, 'w', encoding='utf-8') as f:
f.write(content)
elif self.digest_format == 'pdf':
from weasyprint import HTML, CSS
try:
from weasyprint.fonts import FontConfiguration
except ImportError:
@ -246,37 +300,47 @@ class RssUpdates(HttpRequest):
body_style = 'body { ' + self.body_style + ' }'
font_config = FontConfiguration()
css = [CSS('https://fonts.googleapis.com/css?family=Merriweather'),
CSS(string=body_style, font_config=font_config)]
css = [
CSS('https://fonts.googleapis.com/css?family=Merriweather'),
CSS(string=body_style, font_config=font_config),
]
HTML(string=content).write_pdf(digest_filename, stylesheets=css)
else:
raise RuntimeError('Unsupported format: {}. Supported formats: ' +
'html or pdf'.format(self.digest_format))
raise RuntimeError(
f'Unsupported format: {self.digest_format}. Supported formats: html, pdf'
)
digest_entry = FeedDigest(source_id=source_record.id,
format=self.digest_format,
filename=digest_filename)
digest_entry = FeedDigest(
source_id=source_record.id,
format=self.digest_format,
filename=digest_filename,
)
session.add(digest_entry)
self.logger.info('{} digest ready: {}'.format(self.digest_format, digest_filename))
self.logger.info(
'{} digest ready: {}'.format(self.digest_format, digest_filename)
)
session.commit()
self.logger.info('Parsing RSS feed {}: completed'.format(self.title))
return NewFeedEvent(request=dict(self), response=entries,
source_id=source_record.id,
source_title=source_record.title,
title=self.title,
digest_format=self.digest_format,
digest_filename=digest_filename)
return NewFeedEvent(
request=dict(self),
response=entries,
source_id=source_record.id,
source_title=source_record.title,
title=self.title,
digest_format=self.digest_format,
digest_filename=digest_filename,
)
class FeedSource(Base):
""" Models the FeedSource table, containing RSS sources to be parsed """
"""Models the FeedSource table, containing RSS sources to be parsed"""
__tablename__ = 'FeedSource'
__table_args__ = ({'sqlite_autoincrement': True})
__table_args__ = {'sqlite_autoincrement': True}
id = Column(Integer, primary_key=True)
title = Column(String)
@ -285,10 +349,10 @@ class FeedSource(Base):
class FeedEntry(Base):
""" Models the FeedEntry table, which contains RSS entries """
"""Models the FeedEntry table, which contains RSS entries"""
__tablename__ = 'FeedEntry'
__table_args__ = ({'sqlite_autoincrement': True})
__table_args__ = {'sqlite_autoincrement': True}
id = Column(Integer, primary_key=True)
entry_id = Column(String)
@ -301,15 +365,15 @@ class FeedEntry(Base):
class FeedDigest(Base):
""" Models the FeedDigest table, containing feed digests either in HTML
or PDF format """
"""Models the FeedDigest table, containing feed digests either in HTML
or PDF format"""
class DigestFormat(enum.Enum):
html = 1
pdf = 2
__tablename__ = 'FeedDigest'
__table_args__ = ({'sqlite_autoincrement': True})
__table_args__ = {'sqlite_autoincrement': True}
id = Column(Integer, primary_key=True)
source_id = Column(Integer, ForeignKey('FeedSource.id'), nullable=False)
@ -317,4 +381,5 @@ class FeedDigest(Base):
filename = Column(String, nullable=False)
created_at = Column(DateTime, nullable=False, default=datetime.datetime.utcnow)
# vim:sw=4:ts=4:et:

View File

@ -8,15 +8,18 @@ from queue import Queue, Empty
from threading import Thread, RLock
from typing import List, Dict, Any, Optional, Tuple
from sqlalchemy import create_engine, Column, Integer, String, DateTime
import sqlalchemy.engine as engine
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import engine, create_engine, Column, Integer, String, DateTime
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
from platypush.backend import Backend
from platypush.config import Config
from platypush.context import get_plugin
from platypush.message.event.mail import MailReceivedEvent, MailSeenEvent, MailFlaggedEvent, MailUnflaggedEvent
from platypush.message.event.mail import (
MailReceivedEvent,
MailSeenEvent,
MailFlaggedEvent,
MailUnflaggedEvent,
)
from platypush.plugins.mail import MailInPlugin, Mail
# <editor-fold desc="Database tables">
@ -25,7 +28,8 @@ Session = scoped_session(sessionmaker())
class MailboxStatus(Base):
""" Models the MailboxStatus table, containing information about the state of a monitored mailbox. """
"""Models the MailboxStatus table, containing information about the state of a monitored mailbox."""
__tablename__ = 'MailboxStatus'
mailbox_id = Column(Integer, primary_key=True)
@ -64,8 +68,13 @@ class MailBackend(Backend):
"""
def __init__(self, mailboxes: List[Dict[str, Any]], timeout: Optional[int] = 60, poll_seconds: Optional[int] = 60,
**kwargs):
def __init__(
self,
mailboxes: List[Dict[str, Any]],
timeout: Optional[int] = 60,
poll_seconds: Optional[int] = 60,
**kwargs
):
"""
:param mailboxes: List of mailboxes to be monitored. Each mailbox entry contains a ``plugin`` attribute to
identify the :class:`platypush.plugins.mail.MailInPlugin` plugin that will be used (e.g. ``mail.imap``)
@ -128,9 +137,13 @@ class MailBackend(Backend):
# Parse mailboxes
for i, mbox in enumerate(mailboxes):
assert 'plugin' in mbox, 'No plugin attribute specified for mailbox n.{}'.format(i)
assert (
'plugin' in mbox
), 'No plugin attribute specified for mailbox n.{}'.format(i)
plugin = get_plugin(mbox.pop('plugin'))
assert isinstance(plugin, MailInPlugin), '{} is not a MailInPlugin'.format(plugin)
assert isinstance(plugin, MailInPlugin), '{} is not a MailInPlugin'.format(
plugin
)
name = mbox.pop('name') if 'name' in mbox else 'Mailbox #{}'.format(i + 1)
self.mailboxes.append(Mailbox(plugin=plugin, name=name, args=mbox))
@ -144,7 +157,10 @@ class MailBackend(Backend):
# <editor-fold desc="Database methods">
def _db_get_engine(self) -> engine.Engine:
return create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False})
return create_engine(
'sqlite:///{}'.format(self.dbfile),
connect_args={'check_same_thread': False},
)
def _db_load_mailboxes_status(self) -> None:
mailbox_ids = list(range(len(self.mailboxes)))
@ -153,12 +169,18 @@ class MailBackend(Backend):
session = Session()
records = {
record.mailbox_id: record
for record in session.query(MailboxStatus).filter(MailboxStatus.mailbox_id.in_(mailbox_ids)).all()
for record in session.query(MailboxStatus)
.filter(MailboxStatus.mailbox_id.in_(mailbox_ids))
.all()
}
for mbox_id, mbox in enumerate(self.mailboxes):
for mbox_id, _ in enumerate(self.mailboxes):
if mbox_id not in records:
record = MailboxStatus(mailbox_id=mbox_id, unseen_message_ids='[]', flagged_message_ids='[]')
record = MailboxStatus(
mailbox_id=mbox_id,
unseen_message_ids='[]',
flagged_message_ids='[]',
)
session.add(record)
else:
record = records[mbox_id]
@ -170,19 +192,25 @@ class MailBackend(Backend):
session.commit()
def _db_get_mailbox_status(self, mailbox_ids: List[int]) -> Dict[int, MailboxStatus]:
def _db_get_mailbox_status(
self, mailbox_ids: List[int]
) -> Dict[int, MailboxStatus]:
with self._db_lock:
session = Session()
return {
record.mailbox_id: record
for record in session.query(MailboxStatus).filter(MailboxStatus.mailbox_id.in_(mailbox_ids)).all()
for record in session.query(MailboxStatus)
.filter(MailboxStatus.mailbox_id.in_(mailbox_ids))
.all()
}
# </editor-fold>
# <editor-fold desc="Parse unread messages logic">
@staticmethod
def _check_thread(unread_queue: Queue, flagged_queue: Queue, plugin: MailInPlugin, **args):
def _check_thread(
unread_queue: Queue, flagged_queue: Queue, plugin: MailInPlugin, **args
):
def thread():
# noinspection PyUnresolvedReferences
unread = plugin.search_unseen_messages(**args).output
@ -194,8 +222,9 @@ class MailBackend(Backend):
return thread
def _get_unread_seen_msgs(self, mailbox_idx: int, unread_msgs: Dict[int, Mail]) \
-> Tuple[Dict[int, Mail], Dict[int, Mail]]:
def _get_unread_seen_msgs(
self, mailbox_idx: int, unread_msgs: Dict[int, Mail]
) -> Tuple[Dict[int, Mail], Dict[int, Mail]]:
prev_unread_msgs = self._unread_msgs[mailbox_idx]
return {
@ -208,35 +237,51 @@ class MailBackend(Backend):
if msg_id not in unread_msgs
}
def _get_flagged_unflagged_msgs(self, mailbox_idx: int, flagged_msgs: Dict[int, Mail]) \
-> Tuple[Dict[int, Mail], Dict[int, Mail]]:
def _get_flagged_unflagged_msgs(
self, mailbox_idx: int, flagged_msgs: Dict[int, Mail]
) -> Tuple[Dict[int, Mail], Dict[int, Mail]]:
prev_flagged_msgs = self._flagged_msgs[mailbox_idx]
return {
msg_id: flagged_msgs[msg_id]
for msg_id in flagged_msgs
if msg_id not in prev_flagged_msgs
}, {
msg_id: prev_flagged_msgs[msg_id]
for msg_id in prev_flagged_msgs
if msg_id not in flagged_msgs
}
msg_id: flagged_msgs[msg_id]
for msg_id in flagged_msgs
if msg_id not in prev_flagged_msgs
}, {
msg_id: prev_flagged_msgs[msg_id]
for msg_id in prev_flagged_msgs
if msg_id not in flagged_msgs
}
def _process_msg_events(self, mailbox_id: int, unread: List[Mail], seen: List[Mail],
flagged: List[Mail], unflagged: List[Mail], last_checked_date: Optional[datetime] = None):
def _process_msg_events(
self,
mailbox_id: int,
unread: List[Mail],
seen: List[Mail],
flagged: List[Mail],
unflagged: List[Mail],
last_checked_date: Optional[datetime] = None,
):
for msg in unread:
if msg.date and last_checked_date and msg.date < last_checked_date:
continue
self.bus.post(MailReceivedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg))
self.bus.post(
MailReceivedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)
)
for msg in seen:
self.bus.post(MailSeenEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg))
self.bus.post(
MailSeenEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)
)
for msg in flagged:
self.bus.post(MailFlaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg))
self.bus.post(
MailFlaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)
)
for msg in unflagged:
self.bus.post(MailUnflaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg))
self.bus.post(
MailUnflaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)
)
def _check_mailboxes(self) -> List[Tuple[Dict[int, Mail], Dict[int, Mail]]]:
workers = []
@ -245,8 +290,14 @@ class MailBackend(Backend):
for mbox in self.mailboxes:
unread_queue, flagged_queue = [Queue()] * 2
worker = Thread(target=self._check_thread(unread_queue=unread_queue, flagged_queue=flagged_queue,
plugin=mbox.plugin, **mbox.args))
worker = Thread(
target=self._check_thread(
unread_queue=unread_queue,
flagged_queue=flagged_queue,
plugin=mbox.plugin,
**mbox.args
)
)
worker.start()
workers.append(worker)
queues.append((unread_queue, flagged_queue))
@ -260,7 +311,11 @@ class MailBackend(Backend):
flagged = flagged_queue.get(timeout=self.timeout)
results.append((unread, flagged))
except Empty:
self.logger.warning('Checks on mailbox #{} timed out after {} seconds'.format(i + 1, self.timeout))
self.logger.warning(
'Checks on mailbox #{} timed out after {} seconds'.format(
i + 1, self.timeout
)
)
continue
return results
@ -276,16 +331,25 @@ class MailBackend(Backend):
for i, (unread, flagged) in enumerate(results):
unread_msgs, seen_msgs = self._get_unread_seen_msgs(i, unread)
flagged_msgs, unflagged_msgs = self._get_flagged_unflagged_msgs(i, flagged)
self._process_msg_events(i, unread=list(unread_msgs.values()), seen=list(seen_msgs.values()),
flagged=list(flagged_msgs.values()), unflagged=list(unflagged_msgs.values()),
last_checked_date=mailbox_statuses[i].last_checked_date)
self._process_msg_events(
i,
unread=list(unread_msgs.values()),
seen=list(seen_msgs.values()),
flagged=list(flagged_msgs.values()),
unflagged=list(unflagged_msgs.values()),
last_checked_date=mailbox_statuses[i].last_checked_date,
)
self._unread_msgs[i] = unread
self._flagged_msgs[i] = flagged
records.append(MailboxStatus(mailbox_id=i,
unseen_message_ids=json.dumps([msg_id for msg_id in unread.keys()]),
flagged_message_ids=json.dumps([msg_id for msg_id in flagged.keys()]),
last_checked_date=datetime.now()))
records.append(
MailboxStatus(
mailbox_id=i,
unseen_message_ids=json.dumps(list(unread.keys())),
flagged_message_ids=json.dumps(list(flagged.keys())),
last_checked_date=datetime.now(),
)
)
with self._db_lock:
session = Session()

View File

@ -5,7 +5,7 @@ from typing import Mapping, Type
import pkgutil
from sqlalchemy import Column, Index, Integer, String, DateTime, JSON, UniqueConstraint
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import declarative_base
Base = declarative_base()
entities_registry: Mapping[Type['Entity'], Mapping] = {}
@ -24,14 +24,16 @@ class Entity(Base):
type = Column(String, nullable=False, index=True)
plugin = Column(String, nullable=False)
data = Column(JSON, default=dict)
created_at = Column(DateTime(timezone=False), default=datetime.utcnow(), nullable=False)
updated_at = Column(DateTime(timezone=False), default=datetime.utcnow(), onupdate=datetime.now())
created_at = Column(
DateTime(timezone=False), default=datetime.utcnow(), nullable=False
)
updated_at = Column(
DateTime(timezone=False), default=datetime.utcnow(), onupdate=datetime.now()
)
UniqueConstraint(external_id, plugin)
__table_args__ = (
Index(name, plugin),
)
__table_args__ = (Index(name, plugin),)
__mapper_args__ = {
'polymorphic_identity': __tablename__,
@ -41,13 +43,14 @@ class Entity(Base):
def _discover_entity_types():
from platypush.context import get_plugin
logger = get_plugin('logger')
assert logger
for loader, modname, _ in pkgutil.walk_packages(
path=[str(pathlib.Path(__file__).parent.absolute())],
prefix=__package__ + '.',
onerror=lambda _: None
onerror=lambda _: None,
):
try:
mod_loader = loader.find_module(modname) # type: ignore
@ -65,9 +68,9 @@ def _discover_entity_types():
def init_entities_db():
from platypush.context import get_plugin
_discover_entity_types()
db = get_plugin('db')
assert db
engine = db.get_engine()
db.create_all(engine, Base)

View File

@ -3,9 +3,16 @@ import os
import re
import time
from sqlalchemy import create_engine, Column, Integer, String, DateTime, PrimaryKeyConstraint, ForeignKey
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import (
create_engine,
Column,
Integer,
String,
DateTime,
PrimaryKeyConstraint,
ForeignKey,
)
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
from sqlalchemy.sql.expression import func
from platypush.config import Config
@ -38,7 +45,8 @@ class LocalMediaSearcher(MediaSearcher):
if not self._db_engine:
self._db_engine = create_engine(
'sqlite:///{}'.format(self.db_file),
connect_args={'check_same_thread': False})
connect_args={'check_same_thread': False},
)
Base.metadata.create_all(self._db_engine)
Session.configure(bind=self._db_engine)
@ -57,27 +65,30 @@ class LocalMediaSearcher(MediaSearcher):
@classmethod
def _get_last_modify_time(cls, path, recursive=False):
return max([os.path.getmtime(p) for p, _, _ in os.walk(path)]) \
if recursive else os.path.getmtime(path)
return (
max([os.path.getmtime(p) for p, _, _ in os.walk(path)])
if recursive
else os.path.getmtime(path)
)
@classmethod
def _has_directory_changed_since_last_indexing(self, dir_record):
def _has_directory_changed_since_last_indexing(cls, dir_record):
if not dir_record.last_indexed_at:
return True
return datetime.datetime.fromtimestamp(
self._get_last_modify_time(dir_record.path)) > dir_record.last_indexed_at
return (
datetime.datetime.fromtimestamp(cls._get_last_modify_time(dir_record.path))
> dir_record.last_indexed_at
)
@classmethod
def _matches_query(cls, filename, query):
filename = filename.lower()
query_tokens = [_.lower() for _ in re.split(
cls._filename_separators, query.strip())]
query_tokens = [
_.lower() for _ in re.split(cls._filename_separators, query.strip())
]
for token in query_tokens:
if token not in filename:
return False
return True
return all(token in filename for token in query_tokens)
@classmethod
def _sync_token_records(cls, session, *tokens):
@ -85,9 +96,12 @@ class LocalMediaSearcher(MediaSearcher):
if not tokens:
return []
records = {record.token: record for record in
session.query(MediaToken).filter(
MediaToken.token.in_(tokens)).all()}
records = {
record.token: record
for record in session.query(MediaToken)
.filter(MediaToken.token.in_(tokens))
.all()
}
for token in tokens:
if token in records:
@ -97,13 +111,11 @@ class LocalMediaSearcher(MediaSearcher):
records[token] = record
session.commit()
return session.query(MediaToken).filter(
MediaToken.token.in_(tokens)).all()
return session.query(MediaToken).filter(MediaToken.token.in_(tokens)).all()
@classmethod
def _get_file_records(cls, dir_record, session):
return session.query(MediaFile).filter_by(
directory_id=dir_record.id).all()
return session.query(MediaFile).filter_by(directory_id=dir_record.id).all()
def scan(self, media_dir, session=None, dir_record=None):
"""
@ -121,17 +133,19 @@ class LocalMediaSearcher(MediaSearcher):
dir_record = self._get_or_create_dir_entry(session, media_dir)
if not os.path.isdir(media_dir):
self.logger.info('Directory {} is no longer accessible, removing it'.
format(media_dir))
session.query(MediaDirectory) \
.filter(MediaDirectory.path == media_dir) \
.delete(synchronize_session='fetch')
self.logger.info(
'Directory {} is no longer accessible, removing it'.format(media_dir)
)
session.query(MediaDirectory).filter(
MediaDirectory.path == media_dir
).delete(synchronize_session='fetch')
return
stored_file_records = {
f.path: f for f in self._get_file_records(dir_record, session)}
f.path: f for f in self._get_file_records(dir_record, session)
}
for path, dirs, files in os.walk(media_dir):
for path, _, files in os.walk(media_dir):
for filename in files:
filepath = os.path.join(path, filename)
@ -142,26 +156,32 @@ class LocalMediaSearcher(MediaSearcher):
del stored_file_records[filepath]
continue
if not MediaPlugin.is_video_file(filename) and \
not MediaPlugin.is_audio_file(filename):
if not MediaPlugin.is_video_file(
filename
) and not MediaPlugin.is_audio_file(filename):
continue
self.logger.debug('Syncing item {}'.format(filepath))
tokens = [_.lower() for _ in re.split(self._filename_separators,
filename.strip())]
tokens = [
_.lower()
for _ in re.split(self._filename_separators, filename.strip())
]
token_records = self._sync_token_records(session, *tokens)
file_record = MediaFile.build(directory_id=dir_record.id,
path=filepath)
file_record = MediaFile.build(directory_id=dir_record.id, path=filepath)
session.add(file_record)
session.commit()
file_record = session.query(MediaFile).filter_by(
directory_id=dir_record.id, path=filepath).one()
file_record = (
session.query(MediaFile)
.filter_by(directory_id=dir_record.id, path=filepath)
.one()
)
for token_record in token_records:
file_token = MediaFileToken.build(file_id=file_record.id,
token_id=token_record.id)
file_token = MediaFileToken.build(
file_id=file_record.id, token_id=token_record.id
)
session.add(file_token)
# stored_file_records should now only contain the records of the files
@ -169,15 +189,20 @@ class LocalMediaSearcher(MediaSearcher):
if stored_file_records:
self.logger.info(
'Removing references to {} deleted media items from {}'.format(
len(stored_file_records), media_dir))
len(stored_file_records), media_dir
)
)
session.query(MediaFile).filter(MediaFile.id.in_(
[record.id for record in stored_file_records.values()]
)).delete(synchronize_session='fetch')
session.query(MediaFile).filter(
MediaFile.id.in_([record.id for record in stored_file_records.values()])
).delete(synchronize_session='fetch')
dir_record.last_indexed_at = datetime.datetime.now()
self.logger.info('Scanned {} in {} seconds'.format(
media_dir, int(time.time() - index_start_time)))
self.logger.info(
'Scanned {} in {} seconds'.format(
media_dir, int(time.time() - index_start_time)
)
)
session.commit()
@ -197,25 +222,30 @@ class LocalMediaSearcher(MediaSearcher):
dir_record = self._get_or_create_dir_entry(session, media_dir)
if self._has_directory_changed_since_last_indexing(dir_record):
self.logger.info('{} has changed since last indexing, '.format(
media_dir) + 're-indexing')
self.logger.info(
'{} has changed since last indexing, '.format(media_dir)
+ 're-indexing'
)
self.scan(media_dir, session=session, dir_record=dir_record)
query_tokens = [_.lower() for _ in re.split(
self._filename_separators, query.strip())]
query_tokens = [
_.lower() for _ in re.split(self._filename_separators, query.strip())
]
for file_record in session.query(MediaFile.path). \
join(MediaFileToken). \
join(MediaToken). \
filter(MediaToken.token.in_(query_tokens)). \
group_by(MediaFile.path). \
having(func.count(MediaFileToken.token_id) >= len(query_tokens)):
for file_record in (
session.query(MediaFile.path)
.join(MediaFileToken)
.join(MediaToken)
.filter(MediaToken.token.in_(query_tokens))
.group_by(MediaFile.path)
.having(func.count(MediaFileToken.token_id) >= len(query_tokens))
):
if os.path.isfile(file_record.path):
results[file_record.path] = {
'url': 'file://' + file_record.path,
'title': os.path.basename(file_record.path),
'size': os.path.getsize(file_record.path)
'size': os.path.getsize(file_record.path),
}
return results.values()
@ -223,11 +253,12 @@ class LocalMediaSearcher(MediaSearcher):
# --- Table definitions
class MediaDirectory(Base):
""" Models the MediaDirectory table """
"""Models the MediaDirectory table"""
__tablename__ = 'MediaDirectory'
__table_args__ = ({'sqlite_autoincrement': True})
__table_args__ = {'sqlite_autoincrement': True}
id = Column(Integer, primary_key=True)
path = Column(String)
@ -243,14 +274,15 @@ class MediaDirectory(Base):
class MediaFile(Base):
""" Models the MediaFile table """
"""Models the MediaFile table"""
__tablename__ = 'MediaFile'
__table_args__ = ({'sqlite_autoincrement': True})
__table_args__ = {'sqlite_autoincrement': True}
id = Column(Integer, primary_key=True)
directory_id = Column(Integer, ForeignKey(
'MediaDirectory.id', ondelete='CASCADE'), nullable=False)
directory_id = Column(
Integer, ForeignKey('MediaDirectory.id', ondelete='CASCADE'), nullable=False
)
path = Column(String, nullable=False, unique=True)
indexed_at = Column(DateTime)
@ -265,10 +297,10 @@ class MediaFile(Base):
class MediaToken(Base):
""" Models the MediaToken table """
"""Models the MediaToken table"""
__tablename__ = 'MediaToken'
__table_args__ = ({'sqlite_autoincrement': True})
__table_args__ = {'sqlite_autoincrement': True}
id = Column(Integer, primary_key=True)
token = Column(String, nullable=False, unique=True)
@ -282,14 +314,16 @@ class MediaToken(Base):
class MediaFileToken(Base):
""" Models the MediaFileToken table """
"""Models the MediaFileToken table"""
__tablename__ = 'MediaFileToken'
file_id = Column(Integer, ForeignKey('MediaFile.id', ondelete='CASCADE'),
nullable=False)
token_id = Column(Integer, ForeignKey('MediaToken.id', ondelete='CASCADE'),
nullable=False)
file_id = Column(
Integer, ForeignKey('MediaFile.id', ondelete='CASCADE'), nullable=False
)
token_id = Column(
Integer, ForeignKey('MediaToken.id', ondelete='CASCADE'), nullable=False
)
__table_args__ = (PrimaryKeyConstraint(file_id, token_id), {})
@ -301,4 +335,5 @@ class MediaFileToken(Base):
record.token_id = token_id
return record
# vim:sw=4:ts=4:et:

View File

@ -13,11 +13,13 @@ except ImportError:
from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import make_transient
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import make_transient, declarative_base
from platypush.context import get_plugin
from platypush.exceptions.user import InvalidJWTTokenException, InvalidCredentialsException
from platypush.exceptions.user import (
InvalidJWTTokenException,
InvalidCredentialsException,
)
from platypush.utils import get_or_generate_jwt_rsa_key_pair
Base = declarative_base()
@ -68,8 +70,12 @@ class UserManager:
if user:
raise NameError('The user {} already exists'.format(username))
record = User(username=username, password=self._encrypt_password(password),
created_at=datetime.datetime.utcnow(), **kwargs)
record = User(
username=username,
password=self._encrypt_password(password),
created_at=datetime.datetime.utcnow(),
**kwargs
)
session.add(record)
session.commit()
@ -93,10 +99,16 @@ class UserManager:
def authenticate_user_session(self, session_token):
with self.db.get_session() as session:
user_session = session.query(UserSession).filter_by(session_token=session_token).first()
user_session = (
session.query(UserSession)
.filter_by(session_token=session_token)
.first()
)
if not user_session or (
user_session.expires_at and user_session.expires_at < datetime.datetime.utcnow()):
user_session.expires_at
and user_session.expires_at < datetime.datetime.utcnow()
):
return None, None
user = session.query(User).filter_by(user_id=user_session.user_id).first()
@ -108,7 +120,9 @@ class UserManager:
if not user:
raise NameError('No such user: {}'.format(username))
user_sessions = session.query(UserSession).filter_by(user_id=user.user_id).all()
user_sessions = (
session.query(UserSession).filter_by(user_id=user.user_id).all()
)
for user_session in user_sessions:
session.delete(user_session)
@ -118,7 +132,11 @@ class UserManager:
def delete_user_session(self, session_token):
with self.db.get_session() as session:
user_session = session.query(UserSession).filter_by(session_token=session_token).first()
user_session = (
session.query(UserSession)
.filter_by(session_token=session_token)
.first()
)
if not user_session:
return False
@ -134,14 +152,18 @@ class UserManager:
return None
if expires_at:
if isinstance(expires_at, int) or isinstance(expires_at, float):
if isinstance(expires_at, (int, float)):
expires_at = datetime.datetime.fromtimestamp(expires_at)
elif isinstance(expires_at, str):
expires_at = datetime.datetime.fromisoformat(expires_at)
user_session = UserSession(user_id=user.user_id, session_token=self.generate_session_token(),
csrf_token=self.generate_session_token(), created_at=datetime.datetime.utcnow(),
expires_at=expires_at)
user_session = UserSession(
user_id=user.user_id,
session_token=self.generate_session_token(),
csrf_token=self.generate_session_token(),
created_at=datetime.datetime.utcnow(),
expires_at=expires_at,
)
session.add(user_session)
session.commit()
@ -179,9 +201,19 @@ class UserManager:
:param session_token: Session token.
"""
with self.db.get_session() as session:
return session.query(User).join(UserSession).filter_by(session_token=session_token).first()
return (
session.query(User)
.join(UserSession)
.filter_by(session_token=session_token)
.first()
)
def generate_jwt_token(self, username: str, password: str, expires_at: Optional[datetime.datetime] = None) -> str:
def generate_jwt_token(
self,
username: str,
password: str,
expires_at: Optional[datetime.datetime] = None,
) -> str:
"""
Create a user JWT token for API usage.
@ -253,10 +285,10 @@ class UserManager:
class User(Base):
""" Models the User table """
"""Models the User table"""
__tablename__ = 'user'
__table_args__ = ({'sqlite_autoincrement': True})
__table_args__ = {'sqlite_autoincrement': True}
user_id = Column(Integer, primary_key=True)
username = Column(String, unique=True, nullable=False)
@ -265,10 +297,10 @@ class User(Base):
class UserSession(Base):
""" Models the UserSession table """
"""Models the UserSession table"""
__tablename__ = 'user_session'
__table_args__ = ({'sqlite_autoincrement': True})
__table_args__ = {'sqlite_autoincrement': True}
session_id = Column(Integer, primary_key=True)
session_token = Column(String, unique=True, nullable=False)