Replaced deprecated sqlalchemy.ext.declarative with sqlalchemy.orm

This commit is contained in:
Fabio Manganiello 2022-04-05 22:47:44 +02:00
parent 4b7eeaa4ed
commit 8a70f1d38e
Signed by: blacklight
GPG key ID: D90FBA7F76362774
7 changed files with 540 additions and 252 deletions

View file

@ -3,8 +3,7 @@ import os
from typing import Optional, Union, List, Dict, Any from typing import Optional, Union, List, Dict, Any
from sqlalchemy import create_engine, Column, Integer, String, DateTime from sqlalchemy import create_engine, Column, Integer, String, DateTime
from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
from sqlalchemy.ext.declarative import declarative_base
from platypush.backend import Backend from platypush.backend import Backend
from platypush.config import Config from platypush.config import Config
@ -20,7 +19,7 @@ class Covid19Update(Base):
"""Models the Covid19Data table""" """Models the Covid19Data table"""
__tablename__ = 'covid19data' __tablename__ = 'covid19data'
__table_args__ = ({'sqlite_autoincrement': True}) __table_args__ = {'sqlite_autoincrement': True}
country = Column(String, primary_key=True) country = Column(String, primary_key=True)
confirmed = Column(Integer, nullable=False, default=0) confirmed = Column(Integer, nullable=False, default=0)
@ -40,7 +39,12 @@ class Covid19Backend(Backend):
""" """
# noinspection PyProtectedMember # noinspection PyProtectedMember
def __init__(self, country: Optional[Union[str, List[str]]], poll_seconds: Optional[float] = 3600.0, **kwargs): def __init__(
self,
country: Optional[Union[str, List[str]]],
poll_seconds: Optional[float] = 3600.0,
**kwargs
):
""" """
:param country: Default country (or list of countries) to retrieve the stats for. It can either be the full :param country: Default country (or list of countries) to retrieve the stats for. It can either be the full
country name or the country code. Special values: country name or the country code. Special values:
@ -56,7 +60,9 @@ class Covid19Backend(Backend):
super().__init__(poll_seconds=poll_seconds, **kwargs) super().__init__(poll_seconds=poll_seconds, **kwargs)
self._plugin: Covid19Plugin = get_plugin('covid19') self._plugin: Covid19Plugin = get_plugin('covid19')
self.country: List[str] = self._plugin._get_countries(country) self.country: List[str] = self._plugin._get_countries(country)
self.workdir = os.path.join(os.path.expanduser(Config.get('workdir')), 'covid19') self.workdir = os.path.join(
os.path.expanduser(Config.get('workdir')), 'covid19'
)
self.dbfile = os.path.join(self.workdir, 'data.db') self.dbfile = os.path.join(self.workdir, 'data.db')
os.makedirs(self.workdir, exist_ok=True) os.makedirs(self.workdir, exist_ok=True)
@ -67,22 +73,30 @@ class Covid19Backend(Backend):
self.logger.info('Stopped Covid19 backend') self.logger.info('Stopped Covid19 backend')
def _process_update(self, summary: Dict[str, Any], session: Session): def _process_update(self, summary: Dict[str, Any], session: Session):
update_time = datetime.datetime.fromisoformat(summary['Date'].replace('Z', '+00:00')) update_time = datetime.datetime.fromisoformat(
summary['Date'].replace('Z', '+00:00')
)
self.bus.post(Covid19UpdateEvent( self.bus.post(
Covid19UpdateEvent(
country=summary['Country'], country=summary['Country'],
country_code=summary['CountryCode'], country_code=summary['CountryCode'],
confirmed=summary['TotalConfirmed'], confirmed=summary['TotalConfirmed'],
deaths=summary['TotalDeaths'], deaths=summary['TotalDeaths'],
recovered=summary['TotalRecovered'], recovered=summary['TotalRecovered'],
update_time=update_time, update_time=update_time,
)) )
)
session.merge(Covid19Update(country=summary['CountryCode'], session.merge(
Covid19Update(
country=summary['CountryCode'],
confirmed=summary['TotalConfirmed'], confirmed=summary['TotalConfirmed'],
deaths=summary['TotalDeaths'], deaths=summary['TotalDeaths'],
recovered=summary['TotalRecovered'], recovered=summary['TotalRecovered'],
last_updated_at=update_time)) last_updated_at=update_time,
)
)
def loop(self): def loop(self):
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
@ -90,23 +104,30 @@ class Covid19Backend(Backend):
if not summaries: if not summaries:
return return
engine = create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False}) engine = create_engine(
'sqlite:///{}'.format(self.dbfile),
connect_args={'check_same_thread': False},
)
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
Session.configure(bind=engine) Session.configure(bind=engine)
session = Session() session = Session()
last_records = { last_records = {
record.country: record record.country: record
for record in session.query(Covid19Update).filter(Covid19Update.country.in_(self.country)).all() for record in session.query(Covid19Update)
.filter(Covid19Update.country.in_(self.country))
.all()
} }
for summary in summaries: for summary in summaries:
country = summary['CountryCode'] country = summary['CountryCode']
last_record = last_records.get(country) last_record = last_records.get(country)
if not last_record or \ if (
summary['TotalConfirmed'] != last_record.confirmed or \ not last_record
summary['TotalDeaths'] != last_record.deaths or \ or summary['TotalConfirmed'] != last_record.confirmed
summary['TotalRecovered'] != last_record.recovered: or summary['TotalDeaths'] != last_record.deaths
or summary['TotalRecovered'] != last_record.recovered
):
self._process_update(summary=summary, session=session) self._process_update(summary=summary, session=session)
session.commit() session.commit()

View file

@ -6,15 +6,28 @@ from typing import Optional, List
import requests import requests
from sqlalchemy import create_engine, Column, String, DateTime from sqlalchemy import create_engine, Column, String, DateTime
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session
from platypush.backend import Backend from platypush.backend import Backend
from platypush.config import Config from platypush.config import Config
from platypush.message.event.github import GithubPushEvent, GithubCommitCommentEvent, GithubCreateEvent, \ from platypush.message.event.github import (
GithubDeleteEvent, GithubEvent, GithubForkEvent, GithubWikiEvent, GithubIssueCommentEvent, GithubIssueEvent, \ GithubPushEvent,
GithubMemberEvent, GithubPublicEvent, GithubPullRequestEvent, GithubPullRequestReviewCommentEvent, \ GithubCommitCommentEvent,
GithubReleaseEvent, GithubSponsorshipEvent, GithubWatchEvent GithubCreateEvent,
GithubDeleteEvent,
GithubEvent,
GithubForkEvent,
GithubWikiEvent,
GithubIssueCommentEvent,
GithubIssueEvent,
GithubMemberEvent,
GithubPublicEvent,
GithubPullRequestEvent,
GithubPullRequestReviewCommentEvent,
GithubReleaseEvent,
GithubSponsorshipEvent,
GithubWatchEvent,
)
Base = declarative_base() Base = declarative_base()
Session = scoped_session(sessionmaker()) Session = scoped_session(sessionmaker())
@ -71,8 +84,17 @@ class GithubBackend(Backend):
_base_url = 'https://api.github.com' _base_url = 'https://api.github.com'
def __init__(self, user: str, user_token: str, repos: Optional[List[str]] = None, org: Optional[str] = None, def __init__(
poll_seconds: int = 60, max_events_per_scan: Optional[int] = 10, *args, **kwargs): self,
user: str,
user_token: str,
repos: Optional[List[str]] = None,
org: Optional[str] = None,
poll_seconds: int = 60,
max_events_per_scan: Optional[int] = 10,
*args,
**kwargs
):
""" """
If neither ``repos`` nor ``org`` is specified then the backend will monitor all new events on user level. If neither ``repos`` nor ``org`` is specified then the backend will monitor all new events on user level.
@ -102,11 +124,17 @@ class GithubBackend(Backend):
def _request(self, uri: str, method: str = 'get') -> dict: def _request(self, uri: str, method: str = 'get') -> dict:
method = getattr(requests, method.lower()) method = getattr(requests, method.lower())
return method(self._base_url + uri, auth=(self.user, self.user_token), return method(
headers={'Accept': 'application/vnd.github.v3+json'}).json() self._base_url + uri,
auth=(self.user, self.user_token),
headers={'Accept': 'application/vnd.github.v3+json'},
).json()
def _init_db(self): def _init_db(self):
engine = create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False}) engine = create_engine(
'sqlite:///{}'.format(self.dbfile),
connect_args={'check_same_thread': False},
)
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
Session.configure(bind=engine) Session.configure(bind=engine)
@ -128,7 +156,11 @@ class GithubBackend(Backend):
def _get_last_event_time(self, uri: str): def _get_last_event_time(self, uri: str):
with self.db_lock: with self.db_lock:
record = self._get_or_create_resource(uri=uri, session=Session()) record = self._get_or_create_resource(uri=uri, session=Session())
return record.last_updated_at.replace(tzinfo=datetime.timezone.utc) if record.last_updated_at else None return (
record.last_updated_at.replace(tzinfo=datetime.timezone.utc)
if record.last_updated_at
else None
)
def _update_last_event_time(self, uri: str, last_updated_at: datetime.datetime): def _update_last_event_time(self, uri: str, last_updated_at: datetime.datetime):
with self.db_lock: with self.db_lock:
@ -158,9 +190,18 @@ class GithubBackend(Backend):
'WatchEvent': GithubWatchEvent, 'WatchEvent': GithubWatchEvent,
} }
event_type = event_mapping[event['type']] if event['type'] in event_mapping else GithubEvent event_type = (
return event_type(event_type=event['type'], actor=event['actor'], repo=event.get('repo', {}), event_mapping[event['type']]
payload=event['payload'], created_at=cls._to_datetime(event['created_at'])) if event['type'] in event_mapping
else GithubEvent
)
return event_type(
event_type=event['type'],
actor=event['actor'],
repo=event.get('repo', {}),
payload=event['payload'],
created_at=cls._to_datetime(event['created_at']),
)
def _events_monitor(self, uri: str, method: str = 'get'): def _events_monitor(self, uri: str, method: str = 'get'):
def thread(): def thread():
@ -175,7 +216,10 @@ class GithubBackend(Backend):
fired_events = [] fired_events = []
for event in events: for event in events:
if self.max_events_per_scan and len(fired_events) >= self.max_events_per_scan: if (
self.max_events_per_scan
and len(fired_events) >= self.max_events_per_scan
):
break break
event_time = self._to_datetime(event['created_at']) event_time = self._to_datetime(event['created_at'])
@ -189,12 +233,17 @@ class GithubBackend(Backend):
for event in fired_events: for event in fired_events:
self.bus.post(event) self.bus.post(event)
self._update_last_event_time(uri=uri, last_updated_at=new_last_event_time) self._update_last_event_time(
uri=uri, last_updated_at=new_last_event_time
)
except Exception as e: except Exception as e:
self.logger.warning('Encountered exception while fetching events from {}: {}'.format( self.logger.warning(
uri, str(e))) 'Encountered exception while fetching events from {}: {}'.format(
uri, str(e)
)
)
self.logger.exception(e) self.logger.exception(e)
finally:
if self.wait_stop(timeout=self.poll_seconds): if self.wait_stop(timeout=self.poll_seconds):
break break
@ -206,12 +255,30 @@ class GithubBackend(Backend):
if self.repos: if self.repos:
for repo in self.repos: for repo in self.repos:
monitors.append(threading.Thread(target=self._events_monitor('/networks/{repo}/events'.format(repo=repo)))) monitors.append(
threading.Thread(
target=self._events_monitor(
'/networks/{repo}/events'.format(repo=repo)
)
)
)
if self.org: if self.org:
monitors.append(threading.Thread(target=self._events_monitor('/orgs/{org}/events'.format(org=self.org)))) monitors.append(
threading.Thread(
target=self._events_monitor(
'/orgs/{org}/events'.format(org=self.org)
)
)
)
if not (self.repos or self.org): if not (self.repos or self.org):
monitors.append(threading.Thread(target=self._events_monitor('/users/{user}/events'.format(user=self.user)))) monitors.append(
threading.Thread(
target=self._events_monitor(
'/users/{user}/events'.format(user=self.user)
)
)
)
for monitor in monitors: for monitor in monitors:
monitor.start() monitor.start()
@ -222,4 +289,5 @@ class GithubBackend(Backend):
self.logger.info('Github backend terminated') self.logger.info('Github backend terminated')
# vim:sw=4:ts=4:et: # vim:sw=4:ts=4:et:

View file

@ -2,11 +2,17 @@ import datetime
import enum import enum
import os import os
from sqlalchemy import create_engine, Column, Integer, String, DateTime, \ from sqlalchemy import (
Enum, ForeignKey create_engine,
Column,
Integer,
String,
DateTime,
Enum,
ForeignKey,
)
from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql.expression import func from sqlalchemy.sql.expression import func
from platypush.backend.http.request import HttpRequest from platypush.backend.http.request import HttpRequest
@ -44,18 +50,31 @@ class RssUpdates(HttpRequest):
""" """
user_agent = 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) ' + \ user_agent = (
'Chrome/62.0.3202.94 Safari/537.36' 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) '
+ 'Chrome/62.0.3202.94 Safari/537.36'
)
def __init__(self, url, title=None, headers=None, params=None, max_entries=None, def __init__(
extract_content=False, digest_format=None, user_agent: str = user_agent, self,
body_style: str = 'font-size: 22px; ' + url,
'font-family: "Merriweather", Georgia, "Times New Roman", Times, serif;', title=None,
headers=None,
params=None,
max_entries=None,
extract_content=False,
digest_format=None,
user_agent: str = user_agent,
body_style: str = 'font-size: 22px; '
+ 'font-family: "Merriweather", Georgia, "Times New Roman", Times, serif;',
title_style: str = 'margin-top: 30px', title_style: str = 'margin-top: 30px',
subtitle_style: str = 'margin-top: 10px; page-break-after: always', subtitle_style: str = 'margin-top: 10px; page-break-after: always',
article_title_style: str = 'page-break-before: always', article_title_style: str = 'page-break-before: always',
article_link_style: str = 'color: #555; text-decoration: none; border-bottom: 1px dotted', article_link_style: str = 'color: #555; text-decoration: none; border-bottom: 1px dotted',
article_content_style: str = '', *argv, **kwargs): article_content_style: str = '',
*argv,
**kwargs,
):
""" """
:param url: URL to the RSS feed to be monitored. :param url: URL to the RSS feed to be monitored.
:param title: Optional title for the feed. :param title: Optional title for the feed.
@ -91,7 +110,9 @@ class RssUpdates(HttpRequest):
# If true, then the http.webpage plugin will be used to parse the content # If true, then the http.webpage plugin will be used to parse the content
self.extract_content = extract_content self.extract_content = extract_content
self.digest_format = digest_format.lower() if digest_format else None # Supported formats: html, pdf self.digest_format = (
digest_format.lower() if digest_format else None
) # Supported formats: html, pdf
os.makedirs(os.path.expanduser(os.path.dirname(self.dbfile)), exist_ok=True) os.makedirs(os.path.expanduser(os.path.dirname(self.dbfile)), exist_ok=True)
@ -119,7 +140,11 @@ class RssUpdates(HttpRequest):
@staticmethod @staticmethod
def _get_latest_update(session, source_id): def _get_latest_update(session, source_id):
return session.query(func.max(FeedEntry.published)).filter_by(source_id=source_id).scalar() return (
session.query(func.max(FeedEntry.published))
.filter_by(source_id=source_id)
.scalar()
)
def _parse_entry_content(self, link): def _parse_entry_content(self, link):
self.logger.info('Extracting content from {}'.format(link)) self.logger.info('Extracting content from {}'.format(link))
@ -130,14 +155,20 @@ class RssUpdates(HttpRequest):
errors = response.errors errors = response.errors
if not output: if not output:
self.logger.warning('Mercury parser error: {}'.format(errors or '[unknown error]')) self.logger.warning(
'Mercury parser error: {}'.format(errors or '[unknown error]')
)
return return
return output.get('content') return output.get('content')
def get_new_items(self, response): def get_new_items(self, response):
import feedparser import feedparser
engine = create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False})
engine = create_engine(
'sqlite:///{}'.format(self.dbfile),
connect_args={'check_same_thread': False},
)
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
Session.configure(bind=engine) Session.configure(bind=engine)
@ -157,12 +188,16 @@ class RssUpdates(HttpRequest):
content = u''' content = u'''
<h1 style="{title_style}">{title}</h1> <h1 style="{title_style}">{title}</h1>
<h2 style="{subtitle_style}">Feeds digest generated on {creation_date}</h2>'''.\ <h2 style="{subtitle_style}">Feeds digest generated on {creation_date}</h2>'''.format(
format(title_style=self.title_style, title=self.title, subtitle_style=self.subtitle_style, title_style=self.title_style,
creation_date=datetime.datetime.now().strftime('%d %B %Y, %H:%M')) title=self.title,
subtitle_style=self.subtitle_style,
creation_date=datetime.datetime.now().strftime('%d %B %Y, %H:%M'),
)
self.logger.info('Parsed {:d} items from RSS feed <{}>' self.logger.info(
.format(len(feed.entries), self.url)) 'Parsed {:d} items from RSS feed <{}>'.format(len(feed.entries), self.url)
)
for entry in feed.entries: for entry in feed.entries:
if not entry.published_parsed: if not entry.published_parsed:
@ -171,9 +206,10 @@ class RssUpdates(HttpRequest):
try: try:
entry_timestamp = datetime.datetime(*entry.published_parsed[:6]) entry_timestamp = datetime.datetime(*entry.published_parsed[:6])
if latest_update is None \ if latest_update is None or entry_timestamp > latest_update:
or entry_timestamp > latest_update: self.logger.info(
self.logger.info('Processed new item from RSS feed <{}>'.format(self.url)) 'Processed new item from RSS feed <{}>'.format(self.url)
)
entry.summary = entry.summary if hasattr(entry, 'summary') else None entry.summary = entry.summary if hasattr(entry, 'summary') else None
if self.extract_content: if self.extract_content:
@ -188,9 +224,13 @@ class RssUpdates(HttpRequest):
<a href="{link}" target="_blank" style="{article_link_style}">{title}</a> <a href="{link}" target="_blank" style="{article_link_style}">{title}</a>
</h1> </h1>
<div class="_parsed-content" style="{article_content_style}">{content}</div>'''.format( <div class="_parsed-content" style="{article_content_style}">{content}</div>'''.format(
article_title_style=self.article_title_style, article_link_style=self.article_link_style, article_title_style=self.article_title_style,
article_content_style=self.article_content_style, link=entry.link, title=entry.title, article_link_style=self.article_link_style,
content=entry.content) article_content_style=self.article_content_style,
link=entry.link,
title=entry.title,
content=entry.content,
)
e = { e = {
'entry_id': entry.id, 'entry_id': entry.id,
@ -207,21 +247,32 @@ class RssUpdates(HttpRequest):
if self.max_entries and len(entries) > self.max_entries: if self.max_entries and len(entries) > self.max_entries:
break break
except Exception as e: except Exception as e:
self.logger.warning('Exception encountered while parsing RSS ' + self.logger.warning(
'RSS feed {}: {}'.format(entry.link, str(e))) 'Exception encountered while parsing RSS '
+ f'RSS feed {entry.link}: {e}'
)
self.logger.exception(e) self.logger.exception(e)
source_record.last_updated_at = parse_start_time source_record.last_updated_at = parse_start_time
digest_filename = None digest_filename = None
if entries: if entries:
self.logger.info('Parsed {} new entries from the RSS feed {}'.format( self.logger.info(
len(entries), self.title)) 'Parsed {} new entries from the RSS feed {}'.format(
len(entries), self.title
)
)
if self.digest_format: if self.digest_format:
digest_filename = os.path.join(self.workdir, 'cache', '{}_{}.{}'.format( digest_filename = os.path.join(
self.workdir,
'cache',
'{}_{}.{}'.format(
datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'), datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'),
self.title, self.digest_format)) self.title,
self.digest_format,
),
)
os.makedirs(os.path.dirname(digest_filename), exist_ok=True) os.makedirs(os.path.dirname(digest_filename), exist_ok=True)
@ -233,12 +284,15 @@ class RssUpdates(HttpRequest):
</head> </head>
<body style="{body_style}">{content}</body> <body style="{body_style}">{content}</body>
</html> </html>
'''.format(title=self.title, body_style=self.body_style, content=content) '''.format(
title=self.title, body_style=self.body_style, content=content
)
with open(digest_filename, 'w', encoding='utf-8') as f: with open(digest_filename, 'w', encoding='utf-8') as f:
f.write(content) f.write(content)
elif self.digest_format == 'pdf': elif self.digest_format == 'pdf':
from weasyprint import HTML, CSS from weasyprint import HTML, CSS
try: try:
from weasyprint.fonts import FontConfiguration from weasyprint.fonts import FontConfiguration
except ImportError: except ImportError:
@ -246,37 +300,47 @@ class RssUpdates(HttpRequest):
body_style = 'body { ' + self.body_style + ' }' body_style = 'body { ' + self.body_style + ' }'
font_config = FontConfiguration() font_config = FontConfiguration()
css = [CSS('https://fonts.googleapis.com/css?family=Merriweather'), css = [
CSS(string=body_style, font_config=font_config)] CSS('https://fonts.googleapis.com/css?family=Merriweather'),
CSS(string=body_style, font_config=font_config),
]
HTML(string=content).write_pdf(digest_filename, stylesheets=css) HTML(string=content).write_pdf(digest_filename, stylesheets=css)
else: else:
raise RuntimeError('Unsupported format: {}. Supported formats: ' + raise RuntimeError(
'html or pdf'.format(self.digest_format)) f'Unsupported format: {self.digest_format}. Supported formats: html, pdf'
)
digest_entry = FeedDigest(source_id=source_record.id, digest_entry = FeedDigest(
source_id=source_record.id,
format=self.digest_format, format=self.digest_format,
filename=digest_filename) filename=digest_filename,
)
session.add(digest_entry) session.add(digest_entry)
self.logger.info('{} digest ready: {}'.format(self.digest_format, digest_filename)) self.logger.info(
'{} digest ready: {}'.format(self.digest_format, digest_filename)
)
session.commit() session.commit()
self.logger.info('Parsing RSS feed {}: completed'.format(self.title)) self.logger.info('Parsing RSS feed {}: completed'.format(self.title))
return NewFeedEvent(request=dict(self), response=entries, return NewFeedEvent(
request=dict(self),
response=entries,
source_id=source_record.id, source_id=source_record.id,
source_title=source_record.title, source_title=source_record.title,
title=self.title, title=self.title,
digest_format=self.digest_format, digest_format=self.digest_format,
digest_filename=digest_filename) digest_filename=digest_filename,
)
class FeedSource(Base): class FeedSource(Base):
"""Models the FeedSource table, containing RSS sources to be parsed""" """Models the FeedSource table, containing RSS sources to be parsed"""
__tablename__ = 'FeedSource' __tablename__ = 'FeedSource'
__table_args__ = ({'sqlite_autoincrement': True}) __table_args__ = {'sqlite_autoincrement': True}
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
title = Column(String) title = Column(String)
@ -288,7 +352,7 @@ class FeedEntry(Base):
"""Models the FeedEntry table, which contains RSS entries""" """Models the FeedEntry table, which contains RSS entries"""
__tablename__ = 'FeedEntry' __tablename__ = 'FeedEntry'
__table_args__ = ({'sqlite_autoincrement': True}) __table_args__ = {'sqlite_autoincrement': True}
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
entry_id = Column(String) entry_id = Column(String)
@ -309,7 +373,7 @@ class FeedDigest(Base):
pdf = 2 pdf = 2
__tablename__ = 'FeedDigest' __tablename__ = 'FeedDigest'
__table_args__ = ({'sqlite_autoincrement': True}) __table_args__ = {'sqlite_autoincrement': True}
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
source_id = Column(Integer, ForeignKey('FeedSource.id'), nullable=False) source_id = Column(Integer, ForeignKey('FeedSource.id'), nullable=False)
@ -317,4 +381,5 @@ class FeedDigest(Base):
filename = Column(String, nullable=False) filename = Column(String, nullable=False)
created_at = Column(DateTime, nullable=False, default=datetime.datetime.utcnow) created_at = Column(DateTime, nullable=False, default=datetime.datetime.utcnow)
# vim:sw=4:ts=4:et: # vim:sw=4:ts=4:et:

View file

@ -8,15 +8,18 @@ from queue import Queue, Empty
from threading import Thread, RLock from threading import Thread, RLock
from typing import List, Dict, Any, Optional, Tuple from typing import List, Dict, Any, Optional, Tuple
from sqlalchemy import create_engine, Column, Integer, String, DateTime from sqlalchemy import engine, create_engine, Column, Integer, String, DateTime
import sqlalchemy.engine as engine from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base
from platypush.backend import Backend from platypush.backend import Backend
from platypush.config import Config from platypush.config import Config
from platypush.context import get_plugin from platypush.context import get_plugin
from platypush.message.event.mail import MailReceivedEvent, MailSeenEvent, MailFlaggedEvent, MailUnflaggedEvent from platypush.message.event.mail import (
MailReceivedEvent,
MailSeenEvent,
MailFlaggedEvent,
MailUnflaggedEvent,
)
from platypush.plugins.mail import MailInPlugin, Mail from platypush.plugins.mail import MailInPlugin, Mail
# <editor-fold desc="Database tables"> # <editor-fold desc="Database tables">
@ -26,6 +29,7 @@ Session = scoped_session(sessionmaker())
class MailboxStatus(Base): class MailboxStatus(Base):
"""Models the MailboxStatus table, containing information about the state of a monitored mailbox.""" """Models the MailboxStatus table, containing information about the state of a monitored mailbox."""
__tablename__ = 'MailboxStatus' __tablename__ = 'MailboxStatus'
mailbox_id = Column(Integer, primary_key=True) mailbox_id = Column(Integer, primary_key=True)
@ -64,8 +68,13 @@ class MailBackend(Backend):
""" """
def __init__(self, mailboxes: List[Dict[str, Any]], timeout: Optional[int] = 60, poll_seconds: Optional[int] = 60, def __init__(
**kwargs): self,
mailboxes: List[Dict[str, Any]],
timeout: Optional[int] = 60,
poll_seconds: Optional[int] = 60,
**kwargs
):
""" """
:param mailboxes: List of mailboxes to be monitored. Each mailbox entry contains a ``plugin`` attribute to :param mailboxes: List of mailboxes to be monitored. Each mailbox entry contains a ``plugin`` attribute to
identify the :class:`platypush.plugins.mail.MailInPlugin` plugin that will be used (e.g. ``mail.imap``) identify the :class:`platypush.plugins.mail.MailInPlugin` plugin that will be used (e.g. ``mail.imap``)
@ -128,9 +137,13 @@ class MailBackend(Backend):
# Parse mailboxes # Parse mailboxes
for i, mbox in enumerate(mailboxes): for i, mbox in enumerate(mailboxes):
assert 'plugin' in mbox, 'No plugin attribute specified for mailbox n.{}'.format(i) assert (
'plugin' in mbox
), 'No plugin attribute specified for mailbox n.{}'.format(i)
plugin = get_plugin(mbox.pop('plugin')) plugin = get_plugin(mbox.pop('plugin'))
assert isinstance(plugin, MailInPlugin), '{} is not a MailInPlugin'.format(plugin) assert isinstance(plugin, MailInPlugin), '{} is not a MailInPlugin'.format(
plugin
)
name = mbox.pop('name') if 'name' in mbox else 'Mailbox #{}'.format(i + 1) name = mbox.pop('name') if 'name' in mbox else 'Mailbox #{}'.format(i + 1)
self.mailboxes.append(Mailbox(plugin=plugin, name=name, args=mbox)) self.mailboxes.append(Mailbox(plugin=plugin, name=name, args=mbox))
@ -144,7 +157,10 @@ class MailBackend(Backend):
# <editor-fold desc="Database methods"> # <editor-fold desc="Database methods">
def _db_get_engine(self) -> engine.Engine: def _db_get_engine(self) -> engine.Engine:
return create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False}) return create_engine(
'sqlite:///{}'.format(self.dbfile),
connect_args={'check_same_thread': False},
)
def _db_load_mailboxes_status(self) -> None: def _db_load_mailboxes_status(self) -> None:
mailbox_ids = list(range(len(self.mailboxes))) mailbox_ids = list(range(len(self.mailboxes)))
@ -153,12 +169,18 @@ class MailBackend(Backend):
session = Session() session = Session()
records = { records = {
record.mailbox_id: record record.mailbox_id: record
for record in session.query(MailboxStatus).filter(MailboxStatus.mailbox_id.in_(mailbox_ids)).all() for record in session.query(MailboxStatus)
.filter(MailboxStatus.mailbox_id.in_(mailbox_ids))
.all()
} }
for mbox_id, mbox in enumerate(self.mailboxes): for mbox_id, _ in enumerate(self.mailboxes):
if mbox_id not in records: if mbox_id not in records:
record = MailboxStatus(mailbox_id=mbox_id, unseen_message_ids='[]', flagged_message_ids='[]') record = MailboxStatus(
mailbox_id=mbox_id,
unseen_message_ids='[]',
flagged_message_ids='[]',
)
session.add(record) session.add(record)
else: else:
record = records[mbox_id] record = records[mbox_id]
@ -170,19 +192,25 @@ class MailBackend(Backend):
session.commit() session.commit()
def _db_get_mailbox_status(self, mailbox_ids: List[int]) -> Dict[int, MailboxStatus]: def _db_get_mailbox_status(
self, mailbox_ids: List[int]
) -> Dict[int, MailboxStatus]:
with self._db_lock: with self._db_lock:
session = Session() session = Session()
return { return {
record.mailbox_id: record record.mailbox_id: record
for record in session.query(MailboxStatus).filter(MailboxStatus.mailbox_id.in_(mailbox_ids)).all() for record in session.query(MailboxStatus)
.filter(MailboxStatus.mailbox_id.in_(mailbox_ids))
.all()
} }
# </editor-fold> # </editor-fold>
# <editor-fold desc="Parse unread messages logic"> # <editor-fold desc="Parse unread messages logic">
@staticmethod @staticmethod
def _check_thread(unread_queue: Queue, flagged_queue: Queue, plugin: MailInPlugin, **args): def _check_thread(
unread_queue: Queue, flagged_queue: Queue, plugin: MailInPlugin, **args
):
def thread(): def thread():
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
unread = plugin.search_unseen_messages(**args).output unread = plugin.search_unseen_messages(**args).output
@ -194,8 +222,9 @@ class MailBackend(Backend):
return thread return thread
def _get_unread_seen_msgs(self, mailbox_idx: int, unread_msgs: Dict[int, Mail]) \ def _get_unread_seen_msgs(
-> Tuple[Dict[int, Mail], Dict[int, Mail]]: self, mailbox_idx: int, unread_msgs: Dict[int, Mail]
) -> Tuple[Dict[int, Mail], Dict[int, Mail]]:
prev_unread_msgs = self._unread_msgs[mailbox_idx] prev_unread_msgs = self._unread_msgs[mailbox_idx]
return { return {
@ -208,8 +237,9 @@ class MailBackend(Backend):
if msg_id not in unread_msgs if msg_id not in unread_msgs
} }
def _get_flagged_unflagged_msgs(self, mailbox_idx: int, flagged_msgs: Dict[int, Mail]) \ def _get_flagged_unflagged_msgs(
-> Tuple[Dict[int, Mail], Dict[int, Mail]]: self, mailbox_idx: int, flagged_msgs: Dict[int, Mail]
) -> Tuple[Dict[int, Mail], Dict[int, Mail]]:
prev_flagged_msgs = self._flagged_msgs[mailbox_idx] prev_flagged_msgs = self._flagged_msgs[mailbox_idx]
return { return {
@ -222,21 +252,36 @@ class MailBackend(Backend):
if msg_id not in flagged_msgs if msg_id not in flagged_msgs
} }
def _process_msg_events(self, mailbox_id: int, unread: List[Mail], seen: List[Mail], def _process_msg_events(
flagged: List[Mail], unflagged: List[Mail], last_checked_date: Optional[datetime] = None): self,
mailbox_id: int,
unread: List[Mail],
seen: List[Mail],
flagged: List[Mail],
unflagged: List[Mail],
last_checked_date: Optional[datetime] = None,
):
for msg in unread: for msg in unread:
if msg.date and last_checked_date and msg.date < last_checked_date: if msg.date and last_checked_date and msg.date < last_checked_date:
continue continue
self.bus.post(MailReceivedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)) self.bus.post(
MailReceivedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)
)
for msg in seen: for msg in seen:
self.bus.post(MailSeenEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)) self.bus.post(
MailSeenEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)
)
for msg in flagged: for msg in flagged:
self.bus.post(MailFlaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)) self.bus.post(
MailFlaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)
)
for msg in unflagged: for msg in unflagged:
self.bus.post(MailUnflaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)) self.bus.post(
MailUnflaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)
)
def _check_mailboxes(self) -> List[Tuple[Dict[int, Mail], Dict[int, Mail]]]: def _check_mailboxes(self) -> List[Tuple[Dict[int, Mail], Dict[int, Mail]]]:
workers = [] workers = []
@ -245,8 +290,14 @@ class MailBackend(Backend):
for mbox in self.mailboxes: for mbox in self.mailboxes:
unread_queue, flagged_queue = [Queue()] * 2 unread_queue, flagged_queue = [Queue()] * 2
worker = Thread(target=self._check_thread(unread_queue=unread_queue, flagged_queue=flagged_queue, worker = Thread(
plugin=mbox.plugin, **mbox.args)) target=self._check_thread(
unread_queue=unread_queue,
flagged_queue=flagged_queue,
plugin=mbox.plugin,
**mbox.args
)
)
worker.start() worker.start()
workers.append(worker) workers.append(worker)
queues.append((unread_queue, flagged_queue)) queues.append((unread_queue, flagged_queue))
@ -260,7 +311,11 @@ class MailBackend(Backend):
flagged = flagged_queue.get(timeout=self.timeout) flagged = flagged_queue.get(timeout=self.timeout)
results.append((unread, flagged)) results.append((unread, flagged))
except Empty: except Empty:
self.logger.warning('Checks on mailbox #{} timed out after {} seconds'.format(i + 1, self.timeout)) self.logger.warning(
'Checks on mailbox #{} timed out after {} seconds'.format(
i + 1, self.timeout
)
)
continue continue
return results return results
@ -276,16 +331,25 @@ class MailBackend(Backend):
for i, (unread, flagged) in enumerate(results): for i, (unread, flagged) in enumerate(results):
unread_msgs, seen_msgs = self._get_unread_seen_msgs(i, unread) unread_msgs, seen_msgs = self._get_unread_seen_msgs(i, unread)
flagged_msgs, unflagged_msgs = self._get_flagged_unflagged_msgs(i, flagged) flagged_msgs, unflagged_msgs = self._get_flagged_unflagged_msgs(i, flagged)
self._process_msg_events(i, unread=list(unread_msgs.values()), seen=list(seen_msgs.values()), self._process_msg_events(
flagged=list(flagged_msgs.values()), unflagged=list(unflagged_msgs.values()), i,
last_checked_date=mailbox_statuses[i].last_checked_date) unread=list(unread_msgs.values()),
seen=list(seen_msgs.values()),
flagged=list(flagged_msgs.values()),
unflagged=list(unflagged_msgs.values()),
last_checked_date=mailbox_statuses[i].last_checked_date,
)
self._unread_msgs[i] = unread self._unread_msgs[i] = unread
self._flagged_msgs[i] = flagged self._flagged_msgs[i] = flagged
records.append(MailboxStatus(mailbox_id=i, records.append(
unseen_message_ids=json.dumps([msg_id for msg_id in unread.keys()]), MailboxStatus(
flagged_message_ids=json.dumps([msg_id for msg_id in flagged.keys()]), mailbox_id=i,
last_checked_date=datetime.now())) unseen_message_ids=json.dumps(list(unread.keys())),
flagged_message_ids=json.dumps(list(flagged.keys())),
last_checked_date=datetime.now(),
)
)
with self._db_lock: with self._db_lock:
session = Session() session = Session()

View file

@ -5,7 +5,7 @@ from typing import Mapping, Type
import pkgutil import pkgutil
from sqlalchemy import Column, Index, Integer, String, DateTime, JSON, UniqueConstraint from sqlalchemy import Column, Index, Integer, String, DateTime, JSON, UniqueConstraint
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import declarative_base
Base = declarative_base() Base = declarative_base()
entities_registry: Mapping[Type['Entity'], Mapping] = {} entities_registry: Mapping[Type['Entity'], Mapping] = {}
@ -24,14 +24,16 @@ class Entity(Base):
type = Column(String, nullable=False, index=True) type = Column(String, nullable=False, index=True)
plugin = Column(String, nullable=False) plugin = Column(String, nullable=False)
data = Column(JSON, default=dict) data = Column(JSON, default=dict)
created_at = Column(DateTime(timezone=False), default=datetime.utcnow(), nullable=False) created_at = Column(
updated_at = Column(DateTime(timezone=False), default=datetime.utcnow(), onupdate=datetime.now()) DateTime(timezone=False), default=datetime.utcnow(), nullable=False
)
updated_at = Column(
DateTime(timezone=False), default=datetime.utcnow(), onupdate=datetime.now()
)
UniqueConstraint(external_id, plugin) UniqueConstraint(external_id, plugin)
__table_args__ = ( __table_args__ = (Index(name, plugin),)
Index(name, plugin),
)
__mapper_args__ = { __mapper_args__ = {
'polymorphic_identity': __tablename__, 'polymorphic_identity': __tablename__,
@ -41,13 +43,14 @@ class Entity(Base):
def _discover_entity_types(): def _discover_entity_types():
from platypush.context import get_plugin from platypush.context import get_plugin
logger = get_plugin('logger') logger = get_plugin('logger')
assert logger assert logger
for loader, modname, _ in pkgutil.walk_packages( for loader, modname, _ in pkgutil.walk_packages(
path=[str(pathlib.Path(__file__).parent.absolute())], path=[str(pathlib.Path(__file__).parent.absolute())],
prefix=__package__ + '.', prefix=__package__ + '.',
onerror=lambda _: None onerror=lambda _: None,
): ):
try: try:
mod_loader = loader.find_module(modname) # type: ignore mod_loader = loader.find_module(modname) # type: ignore
@ -65,9 +68,9 @@ def _discover_entity_types():
def init_entities_db(): def init_entities_db():
from platypush.context import get_plugin from platypush.context import get_plugin
_discover_entity_types() _discover_entity_types()
db = get_plugin('db') db = get_plugin('db')
assert db assert db
engine = db.get_engine() engine = db.get_engine()
db.create_all(engine, Base) db.create_all(engine, Base)

View file

@ -3,9 +3,16 @@ import os
import re import re
import time import time
from sqlalchemy import create_engine, Column, Integer, String, DateTime, PrimaryKeyConstraint, ForeignKey from sqlalchemy import (
from sqlalchemy.orm import sessionmaker, scoped_session create_engine,
from sqlalchemy.ext.declarative import declarative_base Column,
Integer,
String,
DateTime,
PrimaryKeyConstraint,
ForeignKey,
)
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
from sqlalchemy.sql.expression import func from sqlalchemy.sql.expression import func
from platypush.config import Config from platypush.config import Config
@ -38,7 +45,8 @@ class LocalMediaSearcher(MediaSearcher):
if not self._db_engine: if not self._db_engine:
self._db_engine = create_engine( self._db_engine = create_engine(
'sqlite:///{}'.format(self.db_file), 'sqlite:///{}'.format(self.db_file),
connect_args={'check_same_thread': False}) connect_args={'check_same_thread': False},
)
Base.metadata.create_all(self._db_engine) Base.metadata.create_all(self._db_engine)
Session.configure(bind=self._db_engine) Session.configure(bind=self._db_engine)
@ -57,27 +65,30 @@ class LocalMediaSearcher(MediaSearcher):
@classmethod @classmethod
def _get_last_modify_time(cls, path, recursive=False): def _get_last_modify_time(cls, path, recursive=False):
return max([os.path.getmtime(p) for p, _, _ in os.walk(path)]) \ return (
if recursive else os.path.getmtime(path) max([os.path.getmtime(p) for p, _, _ in os.walk(path)])
if recursive
else os.path.getmtime(path)
)
@classmethod @classmethod
def _has_directory_changed_since_last_indexing(self, dir_record): def _has_directory_changed_since_last_indexing(cls, dir_record):
if not dir_record.last_indexed_at: if not dir_record.last_indexed_at:
return True return True
return datetime.datetime.fromtimestamp( return (
self._get_last_modify_time(dir_record.path)) > dir_record.last_indexed_at datetime.datetime.fromtimestamp(cls._get_last_modify_time(dir_record.path))
> dir_record.last_indexed_at
)
@classmethod @classmethod
def _matches_query(cls, filename, query): def _matches_query(cls, filename, query):
filename = filename.lower() filename = filename.lower()
query_tokens = [_.lower() for _ in re.split( query_tokens = [
cls._filename_separators, query.strip())] _.lower() for _ in re.split(cls._filename_separators, query.strip())
]
for token in query_tokens: return all(token in filename for token in query_tokens)
if token not in filename:
return False
return True
@classmethod @classmethod
def _sync_token_records(cls, session, *tokens): def _sync_token_records(cls, session, *tokens):
@ -85,9 +96,12 @@ class LocalMediaSearcher(MediaSearcher):
if not tokens: if not tokens:
return [] return []
records = {record.token: record for record in records = {
session.query(MediaToken).filter( record.token: record
MediaToken.token.in_(tokens)).all()} for record in session.query(MediaToken)
.filter(MediaToken.token.in_(tokens))
.all()
}
for token in tokens: for token in tokens:
if token in records: if token in records:
@ -97,13 +111,11 @@ class LocalMediaSearcher(MediaSearcher):
records[token] = record records[token] = record
session.commit() session.commit()
return session.query(MediaToken).filter( return session.query(MediaToken).filter(MediaToken.token.in_(tokens)).all()
MediaToken.token.in_(tokens)).all()
@classmethod @classmethod
def _get_file_records(cls, dir_record, session): def _get_file_records(cls, dir_record, session):
return session.query(MediaFile).filter_by( return session.query(MediaFile).filter_by(directory_id=dir_record.id).all()
directory_id=dir_record.id).all()
def scan(self, media_dir, session=None, dir_record=None): def scan(self, media_dir, session=None, dir_record=None):
""" """
@ -121,17 +133,19 @@ class LocalMediaSearcher(MediaSearcher):
dir_record = self._get_or_create_dir_entry(session, media_dir) dir_record = self._get_or_create_dir_entry(session, media_dir)
if not os.path.isdir(media_dir): if not os.path.isdir(media_dir):
self.logger.info('Directory {} is no longer accessible, removing it'. self.logger.info(
format(media_dir)) 'Directory {} is no longer accessible, removing it'.format(media_dir)
session.query(MediaDirectory) \ )
.filter(MediaDirectory.path == media_dir) \ session.query(MediaDirectory).filter(
.delete(synchronize_session='fetch') MediaDirectory.path == media_dir
).delete(synchronize_session='fetch')
return return
stored_file_records = { stored_file_records = {
f.path: f for f in self._get_file_records(dir_record, session)} f.path: f for f in self._get_file_records(dir_record, session)
}
for path, dirs, files in os.walk(media_dir): for path, _, files in os.walk(media_dir):
for filename in files: for filename in files:
filepath = os.path.join(path, filename) filepath = os.path.join(path, filename)
@ -142,26 +156,32 @@ class LocalMediaSearcher(MediaSearcher):
del stored_file_records[filepath] del stored_file_records[filepath]
continue continue
if not MediaPlugin.is_video_file(filename) and \ if not MediaPlugin.is_video_file(
not MediaPlugin.is_audio_file(filename): filename
) and not MediaPlugin.is_audio_file(filename):
continue continue
self.logger.debug('Syncing item {}'.format(filepath)) self.logger.debug('Syncing item {}'.format(filepath))
tokens = [_.lower() for _ in re.split(self._filename_separators, tokens = [
filename.strip())] _.lower()
for _ in re.split(self._filename_separators, filename.strip())
]
token_records = self._sync_token_records(session, *tokens) token_records = self._sync_token_records(session, *tokens)
file_record = MediaFile.build(directory_id=dir_record.id, file_record = MediaFile.build(directory_id=dir_record.id, path=filepath)
path=filepath)
session.add(file_record) session.add(file_record)
session.commit() session.commit()
file_record = session.query(MediaFile).filter_by( file_record = (
directory_id=dir_record.id, path=filepath).one() session.query(MediaFile)
.filter_by(directory_id=dir_record.id, path=filepath)
.one()
)
for token_record in token_records: for token_record in token_records:
file_token = MediaFileToken.build(file_id=file_record.id, file_token = MediaFileToken.build(
token_id=token_record.id) file_id=file_record.id, token_id=token_record.id
)
session.add(file_token) session.add(file_token)
# stored_file_records should now only contain the records of the files # stored_file_records should now only contain the records of the files
@ -169,15 +189,20 @@ class LocalMediaSearcher(MediaSearcher):
if stored_file_records: if stored_file_records:
self.logger.info( self.logger.info(
'Removing references to {} deleted media items from {}'.format( 'Removing references to {} deleted media items from {}'.format(
len(stored_file_records), media_dir)) len(stored_file_records), media_dir
)
)
session.query(MediaFile).filter(MediaFile.id.in_( session.query(MediaFile).filter(
[record.id for record in stored_file_records.values()] MediaFile.id.in_([record.id for record in stored_file_records.values()])
)).delete(synchronize_session='fetch') ).delete(synchronize_session='fetch')
dir_record.last_indexed_at = datetime.datetime.now() dir_record.last_indexed_at = datetime.datetime.now()
self.logger.info('Scanned {} in {} seconds'.format( self.logger.info(
media_dir, int(time.time() - index_start_time))) 'Scanned {} in {} seconds'.format(
media_dir, int(time.time() - index_start_time)
)
)
session.commit() session.commit()
@ -197,25 +222,30 @@ class LocalMediaSearcher(MediaSearcher):
dir_record = self._get_or_create_dir_entry(session, media_dir) dir_record = self._get_or_create_dir_entry(session, media_dir)
if self._has_directory_changed_since_last_indexing(dir_record): if self._has_directory_changed_since_last_indexing(dir_record):
self.logger.info('{} has changed since last indexing, '.format( self.logger.info(
media_dir) + 're-indexing') '{} has changed since last indexing, '.format(media_dir)
+ 're-indexing'
)
self.scan(media_dir, session=session, dir_record=dir_record) self.scan(media_dir, session=session, dir_record=dir_record)
query_tokens = [_.lower() for _ in re.split( query_tokens = [
self._filename_separators, query.strip())] _.lower() for _ in re.split(self._filename_separators, query.strip())
]
for file_record in session.query(MediaFile.path). \ for file_record in (
join(MediaFileToken). \ session.query(MediaFile.path)
join(MediaToken). \ .join(MediaFileToken)
filter(MediaToken.token.in_(query_tokens)). \ .join(MediaToken)
group_by(MediaFile.path). \ .filter(MediaToken.token.in_(query_tokens))
having(func.count(MediaFileToken.token_id) >= len(query_tokens)): .group_by(MediaFile.path)
.having(func.count(MediaFileToken.token_id) >= len(query_tokens))
):
if os.path.isfile(file_record.path): if os.path.isfile(file_record.path):
results[file_record.path] = { results[file_record.path] = {
'url': 'file://' + file_record.path, 'url': 'file://' + file_record.path,
'title': os.path.basename(file_record.path), 'title': os.path.basename(file_record.path),
'size': os.path.getsize(file_record.path) 'size': os.path.getsize(file_record.path),
} }
return results.values() return results.values()
@ -223,11 +253,12 @@ class LocalMediaSearcher(MediaSearcher):
# --- Table definitions # --- Table definitions
class MediaDirectory(Base): class MediaDirectory(Base):
"""Models the MediaDirectory table""" """Models the MediaDirectory table"""
__tablename__ = 'MediaDirectory' __tablename__ = 'MediaDirectory'
__table_args__ = ({'sqlite_autoincrement': True}) __table_args__ = {'sqlite_autoincrement': True}
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
path = Column(String) path = Column(String)
@ -246,11 +277,12 @@ class MediaFile(Base):
"""Models the MediaFile table""" """Models the MediaFile table"""
__tablename__ = 'MediaFile' __tablename__ = 'MediaFile'
__table_args__ = ({'sqlite_autoincrement': True}) __table_args__ = {'sqlite_autoincrement': True}
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
directory_id = Column(Integer, ForeignKey( directory_id = Column(
'MediaDirectory.id', ondelete='CASCADE'), nullable=False) Integer, ForeignKey('MediaDirectory.id', ondelete='CASCADE'), nullable=False
)
path = Column(String, nullable=False, unique=True) path = Column(String, nullable=False, unique=True)
indexed_at = Column(DateTime) indexed_at = Column(DateTime)
@ -268,7 +300,7 @@ class MediaToken(Base):
"""Models the MediaToken table""" """Models the MediaToken table"""
__tablename__ = 'MediaToken' __tablename__ = 'MediaToken'
__table_args__ = ({'sqlite_autoincrement': True}) __table_args__ = {'sqlite_autoincrement': True}
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
token = Column(String, nullable=False, unique=True) token = Column(String, nullable=False, unique=True)
@ -286,10 +318,12 @@ class MediaFileToken(Base):
__tablename__ = 'MediaFileToken' __tablename__ = 'MediaFileToken'
file_id = Column(Integer, ForeignKey('MediaFile.id', ondelete='CASCADE'), file_id = Column(
nullable=False) Integer, ForeignKey('MediaFile.id', ondelete='CASCADE'), nullable=False
token_id = Column(Integer, ForeignKey('MediaToken.id', ondelete='CASCADE'), )
nullable=False) token_id = Column(
Integer, ForeignKey('MediaToken.id', ondelete='CASCADE'), nullable=False
)
__table_args__ = (PrimaryKeyConstraint(file_id, token_id), {}) __table_args__ = (PrimaryKeyConstraint(file_id, token_id), {})
@ -301,4 +335,5 @@ class MediaFileToken(Base):
record.token_id = token_id record.token_id = token_id
return record return record
# vim:sw=4:ts=4:et: # vim:sw=4:ts=4:et:

View file

@ -13,11 +13,13 @@ except ImportError:
from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import make_transient from sqlalchemy.orm import make_transient, declarative_base
from sqlalchemy.ext.declarative import declarative_base
from platypush.context import get_plugin from platypush.context import get_plugin
from platypush.exceptions.user import InvalidJWTTokenException, InvalidCredentialsException from platypush.exceptions.user import (
InvalidJWTTokenException,
InvalidCredentialsException,
)
from platypush.utils import get_or_generate_jwt_rsa_key_pair from platypush.utils import get_or_generate_jwt_rsa_key_pair
Base = declarative_base() Base = declarative_base()
@ -68,8 +70,12 @@ class UserManager:
if user: if user:
raise NameError('The user {} already exists'.format(username)) raise NameError('The user {} already exists'.format(username))
record = User(username=username, password=self._encrypt_password(password), record = User(
created_at=datetime.datetime.utcnow(), **kwargs) username=username,
password=self._encrypt_password(password),
created_at=datetime.datetime.utcnow(),
**kwargs
)
session.add(record) session.add(record)
session.commit() session.commit()
@ -93,10 +99,16 @@ class UserManager:
def authenticate_user_session(self, session_token): def authenticate_user_session(self, session_token):
with self.db.get_session() as session: with self.db.get_session() as session:
user_session = session.query(UserSession).filter_by(session_token=session_token).first() user_session = (
session.query(UserSession)
.filter_by(session_token=session_token)
.first()
)
if not user_session or ( if not user_session or (
user_session.expires_at and user_session.expires_at < datetime.datetime.utcnow()): user_session.expires_at
and user_session.expires_at < datetime.datetime.utcnow()
):
return None, None return None, None
user = session.query(User).filter_by(user_id=user_session.user_id).first() user = session.query(User).filter_by(user_id=user_session.user_id).first()
@ -108,7 +120,9 @@ class UserManager:
if not user: if not user:
raise NameError('No such user: {}'.format(username)) raise NameError('No such user: {}'.format(username))
user_sessions = session.query(UserSession).filter_by(user_id=user.user_id).all() user_sessions = (
session.query(UserSession).filter_by(user_id=user.user_id).all()
)
for user_session in user_sessions: for user_session in user_sessions:
session.delete(user_session) session.delete(user_session)
@ -118,7 +132,11 @@ class UserManager:
def delete_user_session(self, session_token): def delete_user_session(self, session_token):
with self.db.get_session() as session: with self.db.get_session() as session:
user_session = session.query(UserSession).filter_by(session_token=session_token).first() user_session = (
session.query(UserSession)
.filter_by(session_token=session_token)
.first()
)
if not user_session: if not user_session:
return False return False
@ -134,14 +152,18 @@ class UserManager:
return None return None
if expires_at: if expires_at:
if isinstance(expires_at, int) or isinstance(expires_at, float): if isinstance(expires_at, (int, float)):
expires_at = datetime.datetime.fromtimestamp(expires_at) expires_at = datetime.datetime.fromtimestamp(expires_at)
elif isinstance(expires_at, str): elif isinstance(expires_at, str):
expires_at = datetime.datetime.fromisoformat(expires_at) expires_at = datetime.datetime.fromisoformat(expires_at)
user_session = UserSession(user_id=user.user_id, session_token=self.generate_session_token(), user_session = UserSession(
csrf_token=self.generate_session_token(), created_at=datetime.datetime.utcnow(), user_id=user.user_id,
expires_at=expires_at) session_token=self.generate_session_token(),
csrf_token=self.generate_session_token(),
created_at=datetime.datetime.utcnow(),
expires_at=expires_at,
)
session.add(user_session) session.add(user_session)
session.commit() session.commit()
@ -179,9 +201,19 @@ class UserManager:
:param session_token: Session token. :param session_token: Session token.
""" """
with self.db.get_session() as session: with self.db.get_session() as session:
return session.query(User).join(UserSession).filter_by(session_token=session_token).first() return (
session.query(User)
.join(UserSession)
.filter_by(session_token=session_token)
.first()
)
def generate_jwt_token(self, username: str, password: str, expires_at: Optional[datetime.datetime] = None) -> str: def generate_jwt_token(
self,
username: str,
password: str,
expires_at: Optional[datetime.datetime] = None,
) -> str:
""" """
Create a user JWT token for API usage. Create a user JWT token for API usage.
@ -256,7 +288,7 @@ class User(Base):
"""Models the User table""" """Models the User table"""
__tablename__ = 'user' __tablename__ = 'user'
__table_args__ = ({'sqlite_autoincrement': True}) __table_args__ = {'sqlite_autoincrement': True}
user_id = Column(Integer, primary_key=True) user_id = Column(Integer, primary_key=True)
username = Column(String, unique=True, nullable=False) username = Column(String, unique=True, nullable=False)
@ -268,7 +300,7 @@ class UserSession(Base):
"""Models the UserSession table""" """Models the UserSession table"""
__tablename__ = 'user_session' __tablename__ = 'user_session'
__table_args__ = ({'sqlite_autoincrement': True}) __table_args__ = {'sqlite_autoincrement': True}
session_id = Column(Integer, primary_key=True) session_id = Column(Integer, primary_key=True)
session_token = Column(String, unique=True, nullable=False) session_token = Column(String, unique=True, nullable=False)