forked from platypush/platypush
Replaced deprecated sqlalchemy.ext.declarative with sqlalchemy.orm
This commit is contained in:
parent
4b7eeaa4ed
commit
8a70f1d38e
7 changed files with 540 additions and 252 deletions
|
@ -3,8 +3,7 @@ import os
|
|||
from typing import Optional, Union, List, Dict, Any
|
||||
|
||||
from sqlalchemy import create_engine, Column, Integer, String, DateTime
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
|
||||
|
||||
from platypush.backend import Backend
|
||||
from platypush.config import Config
|
||||
|
@ -17,10 +16,10 @@ Session = scoped_session(sessionmaker())
|
|||
|
||||
|
||||
class Covid19Update(Base):
|
||||
""" Models the Covid19Data table """
|
||||
"""Models the Covid19Data table"""
|
||||
|
||||
__tablename__ = 'covid19data'
|
||||
__table_args__ = ({'sqlite_autoincrement': True})
|
||||
__table_args__ = {'sqlite_autoincrement': True}
|
||||
|
||||
country = Column(String, primary_key=True)
|
||||
confirmed = Column(Integer, nullable=False, default=0)
|
||||
|
@ -40,7 +39,12 @@ class Covid19Backend(Backend):
|
|||
"""
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
def __init__(self, country: Optional[Union[str, List[str]]], poll_seconds: Optional[float] = 3600.0, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
country: Optional[Union[str, List[str]]],
|
||||
poll_seconds: Optional[float] = 3600.0,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
:param country: Default country (or list of countries) to retrieve the stats for. It can either be the full
|
||||
country name or the country code. Special values:
|
||||
|
@ -56,7 +60,9 @@ class Covid19Backend(Backend):
|
|||
super().__init__(poll_seconds=poll_seconds, **kwargs)
|
||||
self._plugin: Covid19Plugin = get_plugin('covid19')
|
||||
self.country: List[str] = self._plugin._get_countries(country)
|
||||
self.workdir = os.path.join(os.path.expanduser(Config.get('workdir')), 'covid19')
|
||||
self.workdir = os.path.join(
|
||||
os.path.expanduser(Config.get('workdir')), 'covid19'
|
||||
)
|
||||
self.dbfile = os.path.join(self.workdir, 'data.db')
|
||||
os.makedirs(self.workdir, exist_ok=True)
|
||||
|
||||
|
@ -67,22 +73,30 @@ class Covid19Backend(Backend):
|
|||
self.logger.info('Stopped Covid19 backend')
|
||||
|
||||
def _process_update(self, summary: Dict[str, Any], session: Session):
|
||||
update_time = datetime.datetime.fromisoformat(summary['Date'].replace('Z', '+00:00'))
|
||||
update_time = datetime.datetime.fromisoformat(
|
||||
summary['Date'].replace('Z', '+00:00')
|
||||
)
|
||||
|
||||
self.bus.post(Covid19UpdateEvent(
|
||||
self.bus.post(
|
||||
Covid19UpdateEvent(
|
||||
country=summary['Country'],
|
||||
country_code=summary['CountryCode'],
|
||||
confirmed=summary['TotalConfirmed'],
|
||||
deaths=summary['TotalDeaths'],
|
||||
recovered=summary['TotalRecovered'],
|
||||
update_time=update_time,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
session.merge(Covid19Update(country=summary['CountryCode'],
|
||||
session.merge(
|
||||
Covid19Update(
|
||||
country=summary['CountryCode'],
|
||||
confirmed=summary['TotalConfirmed'],
|
||||
deaths=summary['TotalDeaths'],
|
||||
recovered=summary['TotalRecovered'],
|
||||
last_updated_at=update_time))
|
||||
last_updated_at=update_time,
|
||||
)
|
||||
)
|
||||
|
||||
def loop(self):
|
||||
# noinspection PyUnresolvedReferences
|
||||
|
@ -90,23 +104,30 @@ class Covid19Backend(Backend):
|
|||
if not summaries:
|
||||
return
|
||||
|
||||
engine = create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False})
|
||||
engine = create_engine(
|
||||
'sqlite:///{}'.format(self.dbfile),
|
||||
connect_args={'check_same_thread': False},
|
||||
)
|
||||
Base.metadata.create_all(engine)
|
||||
Session.configure(bind=engine)
|
||||
session = Session()
|
||||
|
||||
last_records = {
|
||||
record.country: record
|
||||
for record in session.query(Covid19Update).filter(Covid19Update.country.in_(self.country)).all()
|
||||
for record in session.query(Covid19Update)
|
||||
.filter(Covid19Update.country.in_(self.country))
|
||||
.all()
|
||||
}
|
||||
|
||||
for summary in summaries:
|
||||
country = summary['CountryCode']
|
||||
last_record = last_records.get(country)
|
||||
if not last_record or \
|
||||
summary['TotalConfirmed'] != last_record.confirmed or \
|
||||
summary['TotalDeaths'] != last_record.deaths or \
|
||||
summary['TotalRecovered'] != last_record.recovered:
|
||||
if (
|
||||
not last_record
|
||||
or summary['TotalConfirmed'] != last_record.confirmed
|
||||
or summary['TotalDeaths'] != last_record.deaths
|
||||
or summary['TotalRecovered'] != last_record.recovered
|
||||
):
|
||||
self._process_update(summary=summary, session=session)
|
||||
|
||||
session.commit()
|
||||
|
|
|
@ -6,15 +6,28 @@ from typing import Optional, List
|
|||
|
||||
import requests
|
||||
from sqlalchemy import create_engine, Column, String, DateTime
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
|
||||
|
||||
from platypush.backend import Backend
|
||||
from platypush.config import Config
|
||||
from platypush.message.event.github import GithubPushEvent, GithubCommitCommentEvent, GithubCreateEvent, \
|
||||
GithubDeleteEvent, GithubEvent, GithubForkEvent, GithubWikiEvent, GithubIssueCommentEvent, GithubIssueEvent, \
|
||||
GithubMemberEvent, GithubPublicEvent, GithubPullRequestEvent, GithubPullRequestReviewCommentEvent, \
|
||||
GithubReleaseEvent, GithubSponsorshipEvent, GithubWatchEvent
|
||||
from platypush.message.event.github import (
|
||||
GithubPushEvent,
|
||||
GithubCommitCommentEvent,
|
||||
GithubCreateEvent,
|
||||
GithubDeleteEvent,
|
||||
GithubEvent,
|
||||
GithubForkEvent,
|
||||
GithubWikiEvent,
|
||||
GithubIssueCommentEvent,
|
||||
GithubIssueEvent,
|
||||
GithubMemberEvent,
|
||||
GithubPublicEvent,
|
||||
GithubPullRequestEvent,
|
||||
GithubPullRequestReviewCommentEvent,
|
||||
GithubReleaseEvent,
|
||||
GithubSponsorshipEvent,
|
||||
GithubWatchEvent,
|
||||
)
|
||||
|
||||
Base = declarative_base()
|
||||
Session = scoped_session(sessionmaker())
|
||||
|
@ -71,8 +84,17 @@ class GithubBackend(Backend):
|
|||
|
||||
_base_url = 'https://api.github.com'
|
||||
|
||||
def __init__(self, user: str, user_token: str, repos: Optional[List[str]] = None, org: Optional[str] = None,
|
||||
poll_seconds: int = 60, max_events_per_scan: Optional[int] = 10, *args, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
user: str,
|
||||
user_token: str,
|
||||
repos: Optional[List[str]] = None,
|
||||
org: Optional[str] = None,
|
||||
poll_seconds: int = 60,
|
||||
max_events_per_scan: Optional[int] = 10,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
If neither ``repos`` nor ``org`` is specified then the backend will monitor all new events on user level.
|
||||
|
||||
|
@ -102,17 +124,23 @@ class GithubBackend(Backend):
|
|||
|
||||
def _request(self, uri: str, method: str = 'get') -> dict:
|
||||
method = getattr(requests, method.lower())
|
||||
return method(self._base_url + uri, auth=(self.user, self.user_token),
|
||||
headers={'Accept': 'application/vnd.github.v3+json'}).json()
|
||||
return method(
|
||||
self._base_url + uri,
|
||||
auth=(self.user, self.user_token),
|
||||
headers={'Accept': 'application/vnd.github.v3+json'},
|
||||
).json()
|
||||
|
||||
def _init_db(self):
|
||||
engine = create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False})
|
||||
engine = create_engine(
|
||||
'sqlite:///{}'.format(self.dbfile),
|
||||
connect_args={'check_same_thread': False},
|
||||
)
|
||||
Base.metadata.create_all(engine)
|
||||
Session.configure(bind=engine)
|
||||
|
||||
@staticmethod
|
||||
def _to_datetime(time_string: str) -> datetime.datetime:
|
||||
""" Convert ISO 8061 string format with leading 'Z' into something understandable by Python """
|
||||
"""Convert ISO 8061 string format with leading 'Z' into something understandable by Python"""
|
||||
return datetime.datetime.fromisoformat(time_string[:-1] + '+00:00')
|
||||
|
||||
@staticmethod
|
||||
|
@ -128,7 +156,11 @@ class GithubBackend(Backend):
|
|||
def _get_last_event_time(self, uri: str):
|
||||
with self.db_lock:
|
||||
record = self._get_or_create_resource(uri=uri, session=Session())
|
||||
return record.last_updated_at.replace(tzinfo=datetime.timezone.utc) if record.last_updated_at else None
|
||||
return (
|
||||
record.last_updated_at.replace(tzinfo=datetime.timezone.utc)
|
||||
if record.last_updated_at
|
||||
else None
|
||||
)
|
||||
|
||||
def _update_last_event_time(self, uri: str, last_updated_at: datetime.datetime):
|
||||
with self.db_lock:
|
||||
|
@ -158,9 +190,18 @@ class GithubBackend(Backend):
|
|||
'WatchEvent': GithubWatchEvent,
|
||||
}
|
||||
|
||||
event_type = event_mapping[event['type']] if event['type'] in event_mapping else GithubEvent
|
||||
return event_type(event_type=event['type'], actor=event['actor'], repo=event.get('repo', {}),
|
||||
payload=event['payload'], created_at=cls._to_datetime(event['created_at']))
|
||||
event_type = (
|
||||
event_mapping[event['type']]
|
||||
if event['type'] in event_mapping
|
||||
else GithubEvent
|
||||
)
|
||||
return event_type(
|
||||
event_type=event['type'],
|
||||
actor=event['actor'],
|
||||
repo=event.get('repo', {}),
|
||||
payload=event['payload'],
|
||||
created_at=cls._to_datetime(event['created_at']),
|
||||
)
|
||||
|
||||
def _events_monitor(self, uri: str, method: str = 'get'):
|
||||
def thread():
|
||||
|
@ -175,7 +216,10 @@ class GithubBackend(Backend):
|
|||
fired_events = []
|
||||
|
||||
for event in events:
|
||||
if self.max_events_per_scan and len(fired_events) >= self.max_events_per_scan:
|
||||
if (
|
||||
self.max_events_per_scan
|
||||
and len(fired_events) >= self.max_events_per_scan
|
||||
):
|
||||
break
|
||||
|
||||
event_time = self._to_datetime(event['created_at'])
|
||||
|
@ -189,12 +233,17 @@ class GithubBackend(Backend):
|
|||
for event in fired_events:
|
||||
self.bus.post(event)
|
||||
|
||||
self._update_last_event_time(uri=uri, last_updated_at=new_last_event_time)
|
||||
self._update_last_event_time(
|
||||
uri=uri, last_updated_at=new_last_event_time
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.warning('Encountered exception while fetching events from {}: {}'.format(
|
||||
uri, str(e)))
|
||||
self.logger.warning(
|
||||
'Encountered exception while fetching events from {}: {}'.format(
|
||||
uri, str(e)
|
||||
)
|
||||
)
|
||||
self.logger.exception(e)
|
||||
finally:
|
||||
|
||||
if self.wait_stop(timeout=self.poll_seconds):
|
||||
break
|
||||
|
||||
|
@ -206,12 +255,30 @@ class GithubBackend(Backend):
|
|||
|
||||
if self.repos:
|
||||
for repo in self.repos:
|
||||
monitors.append(threading.Thread(target=self._events_monitor('/networks/{repo}/events'.format(repo=repo))))
|
||||
monitors.append(
|
||||
threading.Thread(
|
||||
target=self._events_monitor(
|
||||
'/networks/{repo}/events'.format(repo=repo)
|
||||
)
|
||||
)
|
||||
)
|
||||
if self.org:
|
||||
monitors.append(threading.Thread(target=self._events_monitor('/orgs/{org}/events'.format(org=self.org))))
|
||||
monitors.append(
|
||||
threading.Thread(
|
||||
target=self._events_monitor(
|
||||
'/orgs/{org}/events'.format(org=self.org)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if not (self.repos or self.org):
|
||||
monitors.append(threading.Thread(target=self._events_monitor('/users/{user}/events'.format(user=self.user))))
|
||||
monitors.append(
|
||||
threading.Thread(
|
||||
target=self._events_monitor(
|
||||
'/users/{user}/events'.format(user=self.user)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
for monitor in monitors:
|
||||
monitor.start()
|
||||
|
@ -222,4 +289,5 @@ class GithubBackend(Backend):
|
|||
|
||||
self.logger.info('Github backend terminated')
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
||||
|
|
|
@ -2,11 +2,17 @@ import datetime
|
|||
import enum
|
||||
import os
|
||||
|
||||
from sqlalchemy import create_engine, Column, Integer, String, DateTime, \
|
||||
Enum, ForeignKey
|
||||
from sqlalchemy import (
|
||||
create_engine,
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
Enum,
|
||||
ForeignKey,
|
||||
)
|
||||
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
|
||||
from sqlalchemy.sql.expression import func
|
||||
|
||||
from platypush.backend.http.request import HttpRequest
|
||||
|
@ -44,18 +50,31 @@ class RssUpdates(HttpRequest):
|
|||
|
||||
"""
|
||||
|
||||
user_agent = 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) ' + \
|
||||
'Chrome/62.0.3202.94 Safari/537.36'
|
||||
user_agent = (
|
||||
'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) '
|
||||
+ 'Chrome/62.0.3202.94 Safari/537.36'
|
||||
)
|
||||
|
||||
def __init__(self, url, title=None, headers=None, params=None, max_entries=None,
|
||||
extract_content=False, digest_format=None, user_agent: str = user_agent,
|
||||
body_style: str = 'font-size: 22px; ' +
|
||||
'font-family: "Merriweather", Georgia, "Times New Roman", Times, serif;',
|
||||
def __init__(
|
||||
self,
|
||||
url,
|
||||
title=None,
|
||||
headers=None,
|
||||
params=None,
|
||||
max_entries=None,
|
||||
extract_content=False,
|
||||
digest_format=None,
|
||||
user_agent: str = user_agent,
|
||||
body_style: str = 'font-size: 22px; '
|
||||
+ 'font-family: "Merriweather", Georgia, "Times New Roman", Times, serif;',
|
||||
title_style: str = 'margin-top: 30px',
|
||||
subtitle_style: str = 'margin-top: 10px; page-break-after: always',
|
||||
article_title_style: str = 'page-break-before: always',
|
||||
article_link_style: str = 'color: #555; text-decoration: none; border-bottom: 1px dotted',
|
||||
article_content_style: str = '', *argv, **kwargs):
|
||||
article_content_style: str = '',
|
||||
*argv,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
:param url: URL to the RSS feed to be monitored.
|
||||
:param title: Optional title for the feed.
|
||||
|
@ -91,7 +110,9 @@ class RssUpdates(HttpRequest):
|
|||
# If true, then the http.webpage plugin will be used to parse the content
|
||||
self.extract_content = extract_content
|
||||
|
||||
self.digest_format = digest_format.lower() if digest_format else None # Supported formats: html, pdf
|
||||
self.digest_format = (
|
||||
digest_format.lower() if digest_format else None
|
||||
) # Supported formats: html, pdf
|
||||
|
||||
os.makedirs(os.path.expanduser(os.path.dirname(self.dbfile)), exist_ok=True)
|
||||
|
||||
|
@ -119,7 +140,11 @@ class RssUpdates(HttpRequest):
|
|||
|
||||
@staticmethod
|
||||
def _get_latest_update(session, source_id):
|
||||
return session.query(func.max(FeedEntry.published)).filter_by(source_id=source_id).scalar()
|
||||
return (
|
||||
session.query(func.max(FeedEntry.published))
|
||||
.filter_by(source_id=source_id)
|
||||
.scalar()
|
||||
)
|
||||
|
||||
def _parse_entry_content(self, link):
|
||||
self.logger.info('Extracting content from {}'.format(link))
|
||||
|
@ -130,14 +155,20 @@ class RssUpdates(HttpRequest):
|
|||
errors = response.errors
|
||||
|
||||
if not output:
|
||||
self.logger.warning('Mercury parser error: {}'.format(errors or '[unknown error]'))
|
||||
self.logger.warning(
|
||||
'Mercury parser error: {}'.format(errors or '[unknown error]')
|
||||
)
|
||||
return
|
||||
|
||||
return output.get('content')
|
||||
|
||||
def get_new_items(self, response):
|
||||
import feedparser
|
||||
engine = create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False})
|
||||
|
||||
engine = create_engine(
|
||||
'sqlite:///{}'.format(self.dbfile),
|
||||
connect_args={'check_same_thread': False},
|
||||
)
|
||||
|
||||
Base.metadata.create_all(engine)
|
||||
Session.configure(bind=engine)
|
||||
|
@ -157,12 +188,16 @@ class RssUpdates(HttpRequest):
|
|||
|
||||
content = u'''
|
||||
<h1 style="{title_style}">{title}</h1>
|
||||
<h2 style="{subtitle_style}">Feeds digest generated on {creation_date}</h2>'''.\
|
||||
format(title_style=self.title_style, title=self.title, subtitle_style=self.subtitle_style,
|
||||
creation_date=datetime.datetime.now().strftime('%d %B %Y, %H:%M'))
|
||||
<h2 style="{subtitle_style}">Feeds digest generated on {creation_date}</h2>'''.format(
|
||||
title_style=self.title_style,
|
||||
title=self.title,
|
||||
subtitle_style=self.subtitle_style,
|
||||
creation_date=datetime.datetime.now().strftime('%d %B %Y, %H:%M'),
|
||||
)
|
||||
|
||||
self.logger.info('Parsed {:d} items from RSS feed <{}>'
|
||||
.format(len(feed.entries), self.url))
|
||||
self.logger.info(
|
||||
'Parsed {:d} items from RSS feed <{}>'.format(len(feed.entries), self.url)
|
||||
)
|
||||
|
||||
for entry in feed.entries:
|
||||
if not entry.published_parsed:
|
||||
|
@ -171,9 +206,10 @@ class RssUpdates(HttpRequest):
|
|||
try:
|
||||
entry_timestamp = datetime.datetime(*entry.published_parsed[:6])
|
||||
|
||||
if latest_update is None \
|
||||
or entry_timestamp > latest_update:
|
||||
self.logger.info('Processed new item from RSS feed <{}>'.format(self.url))
|
||||
if latest_update is None or entry_timestamp > latest_update:
|
||||
self.logger.info(
|
||||
'Processed new item from RSS feed <{}>'.format(self.url)
|
||||
)
|
||||
entry.summary = entry.summary if hasattr(entry, 'summary') else None
|
||||
|
||||
if self.extract_content:
|
||||
|
@ -188,9 +224,13 @@ class RssUpdates(HttpRequest):
|
|||
<a href="{link}" target="_blank" style="{article_link_style}">{title}</a>
|
||||
</h1>
|
||||
<div class="_parsed-content" style="{article_content_style}">{content}</div>'''.format(
|
||||
article_title_style=self.article_title_style, article_link_style=self.article_link_style,
|
||||
article_content_style=self.article_content_style, link=entry.link, title=entry.title,
|
||||
content=entry.content)
|
||||
article_title_style=self.article_title_style,
|
||||
article_link_style=self.article_link_style,
|
||||
article_content_style=self.article_content_style,
|
||||
link=entry.link,
|
||||
title=entry.title,
|
||||
content=entry.content,
|
||||
)
|
||||
|
||||
e = {
|
||||
'entry_id': entry.id,
|
||||
|
@ -207,21 +247,32 @@ class RssUpdates(HttpRequest):
|
|||
if self.max_entries and len(entries) > self.max_entries:
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.warning('Exception encountered while parsing RSS ' +
|
||||
'RSS feed {}: {}'.format(entry.link, str(e)))
|
||||
self.logger.warning(
|
||||
'Exception encountered while parsing RSS '
|
||||
+ f'RSS feed {entry.link}: {e}'
|
||||
)
|
||||
self.logger.exception(e)
|
||||
|
||||
source_record.last_updated_at = parse_start_time
|
||||
digest_filename = None
|
||||
|
||||
if entries:
|
||||
self.logger.info('Parsed {} new entries from the RSS feed {}'.format(
|
||||
len(entries), self.title))
|
||||
self.logger.info(
|
||||
'Parsed {} new entries from the RSS feed {}'.format(
|
||||
len(entries), self.title
|
||||
)
|
||||
)
|
||||
|
||||
if self.digest_format:
|
||||
digest_filename = os.path.join(self.workdir, 'cache', '{}_{}.{}'.format(
|
||||
digest_filename = os.path.join(
|
||||
self.workdir,
|
||||
'cache',
|
||||
'{}_{}.{}'.format(
|
||||
datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'),
|
||||
self.title, self.digest_format))
|
||||
self.title,
|
||||
self.digest_format,
|
||||
),
|
||||
)
|
||||
|
||||
os.makedirs(os.path.dirname(digest_filename), exist_ok=True)
|
||||
|
||||
|
@ -233,12 +284,15 @@ class RssUpdates(HttpRequest):
|
|||
</head>
|
||||
<body style="{body_style}">{content}</body>
|
||||
</html>
|
||||
'''.format(title=self.title, body_style=self.body_style, content=content)
|
||||
'''.format(
|
||||
title=self.title, body_style=self.body_style, content=content
|
||||
)
|
||||
|
||||
with open(digest_filename, 'w', encoding='utf-8') as f:
|
||||
f.write(content)
|
||||
elif self.digest_format == 'pdf':
|
||||
from weasyprint import HTML, CSS
|
||||
|
||||
try:
|
||||
from weasyprint.fonts import FontConfiguration
|
||||
except ImportError:
|
||||
|
@ -246,37 +300,47 @@ class RssUpdates(HttpRequest):
|
|||
|
||||
body_style = 'body { ' + self.body_style + ' }'
|
||||
font_config = FontConfiguration()
|
||||
css = [CSS('https://fonts.googleapis.com/css?family=Merriweather'),
|
||||
CSS(string=body_style, font_config=font_config)]
|
||||
css = [
|
||||
CSS('https://fonts.googleapis.com/css?family=Merriweather'),
|
||||
CSS(string=body_style, font_config=font_config),
|
||||
]
|
||||
|
||||
HTML(string=content).write_pdf(digest_filename, stylesheets=css)
|
||||
else:
|
||||
raise RuntimeError('Unsupported format: {}. Supported formats: ' +
|
||||
'html or pdf'.format(self.digest_format))
|
||||
raise RuntimeError(
|
||||
f'Unsupported format: {self.digest_format}. Supported formats: html, pdf'
|
||||
)
|
||||
|
||||
digest_entry = FeedDigest(source_id=source_record.id,
|
||||
digest_entry = FeedDigest(
|
||||
source_id=source_record.id,
|
||||
format=self.digest_format,
|
||||
filename=digest_filename)
|
||||
filename=digest_filename,
|
||||
)
|
||||
|
||||
session.add(digest_entry)
|
||||
self.logger.info('{} digest ready: {}'.format(self.digest_format, digest_filename))
|
||||
self.logger.info(
|
||||
'{} digest ready: {}'.format(self.digest_format, digest_filename)
|
||||
)
|
||||
|
||||
session.commit()
|
||||
self.logger.info('Parsing RSS feed {}: completed'.format(self.title))
|
||||
|
||||
return NewFeedEvent(request=dict(self), response=entries,
|
||||
return NewFeedEvent(
|
||||
request=dict(self),
|
||||
response=entries,
|
||||
source_id=source_record.id,
|
||||
source_title=source_record.title,
|
||||
title=self.title,
|
||||
digest_format=self.digest_format,
|
||||
digest_filename=digest_filename)
|
||||
digest_filename=digest_filename,
|
||||
)
|
||||
|
||||
|
||||
class FeedSource(Base):
|
||||
""" Models the FeedSource table, containing RSS sources to be parsed """
|
||||
"""Models the FeedSource table, containing RSS sources to be parsed"""
|
||||
|
||||
__tablename__ = 'FeedSource'
|
||||
__table_args__ = ({'sqlite_autoincrement': True})
|
||||
__table_args__ = {'sqlite_autoincrement': True}
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
title = Column(String)
|
||||
|
@ -285,10 +349,10 @@ class FeedSource(Base):
|
|||
|
||||
|
||||
class FeedEntry(Base):
|
||||
""" Models the FeedEntry table, which contains RSS entries """
|
||||
"""Models the FeedEntry table, which contains RSS entries"""
|
||||
|
||||
__tablename__ = 'FeedEntry'
|
||||
__table_args__ = ({'sqlite_autoincrement': True})
|
||||
__table_args__ = {'sqlite_autoincrement': True}
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
entry_id = Column(String)
|
||||
|
@ -301,15 +365,15 @@ class FeedEntry(Base):
|
|||
|
||||
|
||||
class FeedDigest(Base):
|
||||
""" Models the FeedDigest table, containing feed digests either in HTML
|
||||
or PDF format """
|
||||
"""Models the FeedDigest table, containing feed digests either in HTML
|
||||
or PDF format"""
|
||||
|
||||
class DigestFormat(enum.Enum):
|
||||
html = 1
|
||||
pdf = 2
|
||||
|
||||
__tablename__ = 'FeedDigest'
|
||||
__table_args__ = ({'sqlite_autoincrement': True})
|
||||
__table_args__ = {'sqlite_autoincrement': True}
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
source_id = Column(Integer, ForeignKey('FeedSource.id'), nullable=False)
|
||||
|
@ -317,4 +381,5 @@ class FeedDigest(Base):
|
|||
filename = Column(String, nullable=False)
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.utcnow)
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
||||
|
|
|
@ -8,15 +8,18 @@ from queue import Queue, Empty
|
|||
from threading import Thread, RLock
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
|
||||
from sqlalchemy import create_engine, Column, Integer, String, DateTime
|
||||
import sqlalchemy.engine as engine
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy import engine, create_engine, Column, Integer, String, DateTime
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
|
||||
|
||||
from platypush.backend import Backend
|
||||
from platypush.config import Config
|
||||
from platypush.context import get_plugin
|
||||
from platypush.message.event.mail import MailReceivedEvent, MailSeenEvent, MailFlaggedEvent, MailUnflaggedEvent
|
||||
from platypush.message.event.mail import (
|
||||
MailReceivedEvent,
|
||||
MailSeenEvent,
|
||||
MailFlaggedEvent,
|
||||
MailUnflaggedEvent,
|
||||
)
|
||||
from platypush.plugins.mail import MailInPlugin, Mail
|
||||
|
||||
# <editor-fold desc="Database tables">
|
||||
|
@ -25,7 +28,8 @@ Session = scoped_session(sessionmaker())
|
|||
|
||||
|
||||
class MailboxStatus(Base):
|
||||
""" Models the MailboxStatus table, containing information about the state of a monitored mailbox. """
|
||||
"""Models the MailboxStatus table, containing information about the state of a monitored mailbox."""
|
||||
|
||||
__tablename__ = 'MailboxStatus'
|
||||
|
||||
mailbox_id = Column(Integer, primary_key=True)
|
||||
|
@ -64,8 +68,13 @@ class MailBackend(Backend):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, mailboxes: List[Dict[str, Any]], timeout: Optional[int] = 60, poll_seconds: Optional[int] = 60,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
mailboxes: List[Dict[str, Any]],
|
||||
timeout: Optional[int] = 60,
|
||||
poll_seconds: Optional[int] = 60,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
:param mailboxes: List of mailboxes to be monitored. Each mailbox entry contains a ``plugin`` attribute to
|
||||
identify the :class:`platypush.plugins.mail.MailInPlugin` plugin that will be used (e.g. ``mail.imap``)
|
||||
|
@ -128,9 +137,13 @@ class MailBackend(Backend):
|
|||
|
||||
# Parse mailboxes
|
||||
for i, mbox in enumerate(mailboxes):
|
||||
assert 'plugin' in mbox, 'No plugin attribute specified for mailbox n.{}'.format(i)
|
||||
assert (
|
||||
'plugin' in mbox
|
||||
), 'No plugin attribute specified for mailbox n.{}'.format(i)
|
||||
plugin = get_plugin(mbox.pop('plugin'))
|
||||
assert isinstance(plugin, MailInPlugin), '{} is not a MailInPlugin'.format(plugin)
|
||||
assert isinstance(plugin, MailInPlugin), '{} is not a MailInPlugin'.format(
|
||||
plugin
|
||||
)
|
||||
name = mbox.pop('name') if 'name' in mbox else 'Mailbox #{}'.format(i + 1)
|
||||
self.mailboxes.append(Mailbox(plugin=plugin, name=name, args=mbox))
|
||||
|
||||
|
@ -144,7 +157,10 @@ class MailBackend(Backend):
|
|||
|
||||
# <editor-fold desc="Database methods">
|
||||
def _db_get_engine(self) -> engine.Engine:
|
||||
return create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False})
|
||||
return create_engine(
|
||||
'sqlite:///{}'.format(self.dbfile),
|
||||
connect_args={'check_same_thread': False},
|
||||
)
|
||||
|
||||
def _db_load_mailboxes_status(self) -> None:
|
||||
mailbox_ids = list(range(len(self.mailboxes)))
|
||||
|
@ -153,12 +169,18 @@ class MailBackend(Backend):
|
|||
session = Session()
|
||||
records = {
|
||||
record.mailbox_id: record
|
||||
for record in session.query(MailboxStatus).filter(MailboxStatus.mailbox_id.in_(mailbox_ids)).all()
|
||||
for record in session.query(MailboxStatus)
|
||||
.filter(MailboxStatus.mailbox_id.in_(mailbox_ids))
|
||||
.all()
|
||||
}
|
||||
|
||||
for mbox_id, mbox in enumerate(self.mailboxes):
|
||||
for mbox_id, _ in enumerate(self.mailboxes):
|
||||
if mbox_id not in records:
|
||||
record = MailboxStatus(mailbox_id=mbox_id, unseen_message_ids='[]', flagged_message_ids='[]')
|
||||
record = MailboxStatus(
|
||||
mailbox_id=mbox_id,
|
||||
unseen_message_ids='[]',
|
||||
flagged_message_ids='[]',
|
||||
)
|
||||
session.add(record)
|
||||
else:
|
||||
record = records[mbox_id]
|
||||
|
@ -170,19 +192,25 @@ class MailBackend(Backend):
|
|||
|
||||
session.commit()
|
||||
|
||||
def _db_get_mailbox_status(self, mailbox_ids: List[int]) -> Dict[int, MailboxStatus]:
|
||||
def _db_get_mailbox_status(
|
||||
self, mailbox_ids: List[int]
|
||||
) -> Dict[int, MailboxStatus]:
|
||||
with self._db_lock:
|
||||
session = Session()
|
||||
return {
|
||||
record.mailbox_id: record
|
||||
for record in session.query(MailboxStatus).filter(MailboxStatus.mailbox_id.in_(mailbox_ids)).all()
|
||||
for record in session.query(MailboxStatus)
|
||||
.filter(MailboxStatus.mailbox_id.in_(mailbox_ids))
|
||||
.all()
|
||||
}
|
||||
|
||||
# </editor-fold>
|
||||
|
||||
# <editor-fold desc="Parse unread messages logic">
|
||||
@staticmethod
|
||||
def _check_thread(unread_queue: Queue, flagged_queue: Queue, plugin: MailInPlugin, **args):
|
||||
def _check_thread(
|
||||
unread_queue: Queue, flagged_queue: Queue, plugin: MailInPlugin, **args
|
||||
):
|
||||
def thread():
|
||||
# noinspection PyUnresolvedReferences
|
||||
unread = plugin.search_unseen_messages(**args).output
|
||||
|
@ -194,8 +222,9 @@ class MailBackend(Backend):
|
|||
|
||||
return thread
|
||||
|
||||
def _get_unread_seen_msgs(self, mailbox_idx: int, unread_msgs: Dict[int, Mail]) \
|
||||
-> Tuple[Dict[int, Mail], Dict[int, Mail]]:
|
||||
def _get_unread_seen_msgs(
|
||||
self, mailbox_idx: int, unread_msgs: Dict[int, Mail]
|
||||
) -> Tuple[Dict[int, Mail], Dict[int, Mail]]:
|
||||
prev_unread_msgs = self._unread_msgs[mailbox_idx]
|
||||
|
||||
return {
|
||||
|
@ -208,8 +237,9 @@ class MailBackend(Backend):
|
|||
if msg_id not in unread_msgs
|
||||
}
|
||||
|
||||
def _get_flagged_unflagged_msgs(self, mailbox_idx: int, flagged_msgs: Dict[int, Mail]) \
|
||||
-> Tuple[Dict[int, Mail], Dict[int, Mail]]:
|
||||
def _get_flagged_unflagged_msgs(
|
||||
self, mailbox_idx: int, flagged_msgs: Dict[int, Mail]
|
||||
) -> Tuple[Dict[int, Mail], Dict[int, Mail]]:
|
||||
prev_flagged_msgs = self._flagged_msgs[mailbox_idx]
|
||||
|
||||
return {
|
||||
|
@ -222,21 +252,36 @@ class MailBackend(Backend):
|
|||
if msg_id not in flagged_msgs
|
||||
}
|
||||
|
||||
def _process_msg_events(self, mailbox_id: int, unread: List[Mail], seen: List[Mail],
|
||||
flagged: List[Mail], unflagged: List[Mail], last_checked_date: Optional[datetime] = None):
|
||||
def _process_msg_events(
|
||||
self,
|
||||
mailbox_id: int,
|
||||
unread: List[Mail],
|
||||
seen: List[Mail],
|
||||
flagged: List[Mail],
|
||||
unflagged: List[Mail],
|
||||
last_checked_date: Optional[datetime] = None,
|
||||
):
|
||||
for msg in unread:
|
||||
if msg.date and last_checked_date and msg.date < last_checked_date:
|
||||
continue
|
||||
self.bus.post(MailReceivedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg))
|
||||
self.bus.post(
|
||||
MailReceivedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)
|
||||
)
|
||||
|
||||
for msg in seen:
|
||||
self.bus.post(MailSeenEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg))
|
||||
self.bus.post(
|
||||
MailSeenEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)
|
||||
)
|
||||
|
||||
for msg in flagged:
|
||||
self.bus.post(MailFlaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg))
|
||||
self.bus.post(
|
||||
MailFlaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)
|
||||
)
|
||||
|
||||
for msg in unflagged:
|
||||
self.bus.post(MailUnflaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg))
|
||||
self.bus.post(
|
||||
MailUnflaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)
|
||||
)
|
||||
|
||||
def _check_mailboxes(self) -> List[Tuple[Dict[int, Mail], Dict[int, Mail]]]:
|
||||
workers = []
|
||||
|
@ -245,8 +290,14 @@ class MailBackend(Backend):
|
|||
|
||||
for mbox in self.mailboxes:
|
||||
unread_queue, flagged_queue = [Queue()] * 2
|
||||
worker = Thread(target=self._check_thread(unread_queue=unread_queue, flagged_queue=flagged_queue,
|
||||
plugin=mbox.plugin, **mbox.args))
|
||||
worker = Thread(
|
||||
target=self._check_thread(
|
||||
unread_queue=unread_queue,
|
||||
flagged_queue=flagged_queue,
|
||||
plugin=mbox.plugin,
|
||||
**mbox.args
|
||||
)
|
||||
)
|
||||
worker.start()
|
||||
workers.append(worker)
|
||||
queues.append((unread_queue, flagged_queue))
|
||||
|
@ -260,7 +311,11 @@ class MailBackend(Backend):
|
|||
flagged = flagged_queue.get(timeout=self.timeout)
|
||||
results.append((unread, flagged))
|
||||
except Empty:
|
||||
self.logger.warning('Checks on mailbox #{} timed out after {} seconds'.format(i + 1, self.timeout))
|
||||
self.logger.warning(
|
||||
'Checks on mailbox #{} timed out after {} seconds'.format(
|
||||
i + 1, self.timeout
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
return results
|
||||
|
@ -276,16 +331,25 @@ class MailBackend(Backend):
|
|||
for i, (unread, flagged) in enumerate(results):
|
||||
unread_msgs, seen_msgs = self._get_unread_seen_msgs(i, unread)
|
||||
flagged_msgs, unflagged_msgs = self._get_flagged_unflagged_msgs(i, flagged)
|
||||
self._process_msg_events(i, unread=list(unread_msgs.values()), seen=list(seen_msgs.values()),
|
||||
flagged=list(flagged_msgs.values()), unflagged=list(unflagged_msgs.values()),
|
||||
last_checked_date=mailbox_statuses[i].last_checked_date)
|
||||
self._process_msg_events(
|
||||
i,
|
||||
unread=list(unread_msgs.values()),
|
||||
seen=list(seen_msgs.values()),
|
||||
flagged=list(flagged_msgs.values()),
|
||||
unflagged=list(unflagged_msgs.values()),
|
||||
last_checked_date=mailbox_statuses[i].last_checked_date,
|
||||
)
|
||||
|
||||
self._unread_msgs[i] = unread
|
||||
self._flagged_msgs[i] = flagged
|
||||
records.append(MailboxStatus(mailbox_id=i,
|
||||
unseen_message_ids=json.dumps([msg_id for msg_id in unread.keys()]),
|
||||
flagged_message_ids=json.dumps([msg_id for msg_id in flagged.keys()]),
|
||||
last_checked_date=datetime.now()))
|
||||
records.append(
|
||||
MailboxStatus(
|
||||
mailbox_id=i,
|
||||
unseen_message_ids=json.dumps(list(unread.keys())),
|
||||
flagged_message_ids=json.dumps(list(flagged.keys())),
|
||||
last_checked_date=datetime.now(),
|
||||
)
|
||||
)
|
||||
|
||||
with self._db_lock:
|
||||
session = Session()
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import Mapping, Type
|
|||
|
||||
import pkgutil
|
||||
from sqlalchemy import Column, Index, Integer, String, DateTime, JSON, UniqueConstraint
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
entities_registry: Mapping[Type['Entity'], Mapping] = {}
|
||||
|
@ -24,14 +24,16 @@ class Entity(Base):
|
|||
type = Column(String, nullable=False, index=True)
|
||||
plugin = Column(String, nullable=False)
|
||||
data = Column(JSON, default=dict)
|
||||
created_at = Column(DateTime(timezone=False), default=datetime.utcnow(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=False), default=datetime.utcnow(), onupdate=datetime.now())
|
||||
created_at = Column(
|
||||
DateTime(timezone=False), default=datetime.utcnow(), nullable=False
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=False), default=datetime.utcnow(), onupdate=datetime.now()
|
||||
)
|
||||
|
||||
UniqueConstraint(external_id, plugin)
|
||||
|
||||
__table_args__ = (
|
||||
Index(name, plugin),
|
||||
)
|
||||
__table_args__ = (Index(name, plugin),)
|
||||
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': __tablename__,
|
||||
|
@ -41,13 +43,14 @@ class Entity(Base):
|
|||
|
||||
def _discover_entity_types():
|
||||
from platypush.context import get_plugin
|
||||
|
||||
logger = get_plugin('logger')
|
||||
assert logger
|
||||
|
||||
for loader, modname, _ in pkgutil.walk_packages(
|
||||
path=[str(pathlib.Path(__file__).parent.absolute())],
|
||||
prefix=__package__ + '.',
|
||||
onerror=lambda _: None
|
||||
onerror=lambda _: None,
|
||||
):
|
||||
try:
|
||||
mod_loader = loader.find_module(modname) # type: ignore
|
||||
|
@ -65,9 +68,9 @@ def _discover_entity_types():
|
|||
|
||||
def init_entities_db():
|
||||
from platypush.context import get_plugin
|
||||
|
||||
_discover_entity_types()
|
||||
db = get_plugin('db')
|
||||
assert db
|
||||
engine = db.get_engine()
|
||||
db.create_all(engine, Base)
|
||||
|
||||
|
|
|
@ -3,9 +3,16 @@ import os
|
|||
import re
|
||||
import time
|
||||
|
||||
from sqlalchemy import create_engine, Column, Integer, String, DateTime, PrimaryKeyConstraint, ForeignKey
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy import (
|
||||
create_engine,
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
PrimaryKeyConstraint,
|
||||
ForeignKey,
|
||||
)
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
|
||||
from sqlalchemy.sql.expression import func
|
||||
|
||||
from platypush.config import Config
|
||||
|
@ -38,7 +45,8 @@ class LocalMediaSearcher(MediaSearcher):
|
|||
if not self._db_engine:
|
||||
self._db_engine = create_engine(
|
||||
'sqlite:///{}'.format(self.db_file),
|
||||
connect_args={'check_same_thread': False})
|
||||
connect_args={'check_same_thread': False},
|
||||
)
|
||||
|
||||
Base.metadata.create_all(self._db_engine)
|
||||
Session.configure(bind=self._db_engine)
|
||||
|
@ -57,27 +65,30 @@ class LocalMediaSearcher(MediaSearcher):
|
|||
|
||||
@classmethod
|
||||
def _get_last_modify_time(cls, path, recursive=False):
|
||||
return max([os.path.getmtime(p) for p, _, _ in os.walk(path)]) \
|
||||
if recursive else os.path.getmtime(path)
|
||||
return (
|
||||
max([os.path.getmtime(p) for p, _, _ in os.walk(path)])
|
||||
if recursive
|
||||
else os.path.getmtime(path)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _has_directory_changed_since_last_indexing(self, dir_record):
|
||||
def _has_directory_changed_since_last_indexing(cls, dir_record):
|
||||
if not dir_record.last_indexed_at:
|
||||
return True
|
||||
|
||||
return datetime.datetime.fromtimestamp(
|
||||
self._get_last_modify_time(dir_record.path)) > dir_record.last_indexed_at
|
||||
return (
|
||||
datetime.datetime.fromtimestamp(cls._get_last_modify_time(dir_record.path))
|
||||
> dir_record.last_indexed_at
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _matches_query(cls, filename, query):
|
||||
filename = filename.lower()
|
||||
query_tokens = [_.lower() for _ in re.split(
|
||||
cls._filename_separators, query.strip())]
|
||||
query_tokens = [
|
||||
_.lower() for _ in re.split(cls._filename_separators, query.strip())
|
||||
]
|
||||
|
||||
for token in query_tokens:
|
||||
if token not in filename:
|
||||
return False
|
||||
return True
|
||||
return all(token in filename for token in query_tokens)
|
||||
|
||||
@classmethod
|
||||
def _sync_token_records(cls, session, *tokens):
|
||||
|
@ -85,9 +96,12 @@ class LocalMediaSearcher(MediaSearcher):
|
|||
if not tokens:
|
||||
return []
|
||||
|
||||
records = {record.token: record for record in
|
||||
session.query(MediaToken).filter(
|
||||
MediaToken.token.in_(tokens)).all()}
|
||||
records = {
|
||||
record.token: record
|
||||
for record in session.query(MediaToken)
|
||||
.filter(MediaToken.token.in_(tokens))
|
||||
.all()
|
||||
}
|
||||
|
||||
for token in tokens:
|
||||
if token in records:
|
||||
|
@ -97,13 +111,11 @@ class LocalMediaSearcher(MediaSearcher):
|
|||
records[token] = record
|
||||
|
||||
session.commit()
|
||||
return session.query(MediaToken).filter(
|
||||
MediaToken.token.in_(tokens)).all()
|
||||
return session.query(MediaToken).filter(MediaToken.token.in_(tokens)).all()
|
||||
|
||||
@classmethod
|
||||
def _get_file_records(cls, dir_record, session):
|
||||
return session.query(MediaFile).filter_by(
|
||||
directory_id=dir_record.id).all()
|
||||
return session.query(MediaFile).filter_by(directory_id=dir_record.id).all()
|
||||
|
||||
def scan(self, media_dir, session=None, dir_record=None):
|
||||
"""
|
||||
|
@ -121,17 +133,19 @@ class LocalMediaSearcher(MediaSearcher):
|
|||
dir_record = self._get_or_create_dir_entry(session, media_dir)
|
||||
|
||||
if not os.path.isdir(media_dir):
|
||||
self.logger.info('Directory {} is no longer accessible, removing it'.
|
||||
format(media_dir))
|
||||
session.query(MediaDirectory) \
|
||||
.filter(MediaDirectory.path == media_dir) \
|
||||
.delete(synchronize_session='fetch')
|
||||
self.logger.info(
|
||||
'Directory {} is no longer accessible, removing it'.format(media_dir)
|
||||
)
|
||||
session.query(MediaDirectory).filter(
|
||||
MediaDirectory.path == media_dir
|
||||
).delete(synchronize_session='fetch')
|
||||
return
|
||||
|
||||
stored_file_records = {
|
||||
f.path: f for f in self._get_file_records(dir_record, session)}
|
||||
f.path: f for f in self._get_file_records(dir_record, session)
|
||||
}
|
||||
|
||||
for path, dirs, files in os.walk(media_dir):
|
||||
for path, _, files in os.walk(media_dir):
|
||||
for filename in files:
|
||||
filepath = os.path.join(path, filename)
|
||||
|
||||
|
@ -142,26 +156,32 @@ class LocalMediaSearcher(MediaSearcher):
|
|||
del stored_file_records[filepath]
|
||||
continue
|
||||
|
||||
if not MediaPlugin.is_video_file(filename) and \
|
||||
not MediaPlugin.is_audio_file(filename):
|
||||
if not MediaPlugin.is_video_file(
|
||||
filename
|
||||
) and not MediaPlugin.is_audio_file(filename):
|
||||
continue
|
||||
|
||||
self.logger.debug('Syncing item {}'.format(filepath))
|
||||
tokens = [_.lower() for _ in re.split(self._filename_separators,
|
||||
filename.strip())]
|
||||
tokens = [
|
||||
_.lower()
|
||||
for _ in re.split(self._filename_separators, filename.strip())
|
||||
]
|
||||
|
||||
token_records = self._sync_token_records(session, *tokens)
|
||||
file_record = MediaFile.build(directory_id=dir_record.id,
|
||||
path=filepath)
|
||||
file_record = MediaFile.build(directory_id=dir_record.id, path=filepath)
|
||||
|
||||
session.add(file_record)
|
||||
session.commit()
|
||||
file_record = session.query(MediaFile).filter_by(
|
||||
directory_id=dir_record.id, path=filepath).one()
|
||||
file_record = (
|
||||
session.query(MediaFile)
|
||||
.filter_by(directory_id=dir_record.id, path=filepath)
|
||||
.one()
|
||||
)
|
||||
|
||||
for token_record in token_records:
|
||||
file_token = MediaFileToken.build(file_id=file_record.id,
|
||||
token_id=token_record.id)
|
||||
file_token = MediaFileToken.build(
|
||||
file_id=file_record.id, token_id=token_record.id
|
||||
)
|
||||
session.add(file_token)
|
||||
|
||||
# stored_file_records should now only contain the records of the files
|
||||
|
@ -169,15 +189,20 @@ class LocalMediaSearcher(MediaSearcher):
|
|||
if stored_file_records:
|
||||
self.logger.info(
|
||||
'Removing references to {} deleted media items from {}'.format(
|
||||
len(stored_file_records), media_dir))
|
||||
len(stored_file_records), media_dir
|
||||
)
|
||||
)
|
||||
|
||||
session.query(MediaFile).filter(MediaFile.id.in_(
|
||||
[record.id for record in stored_file_records.values()]
|
||||
)).delete(synchronize_session='fetch')
|
||||
session.query(MediaFile).filter(
|
||||
MediaFile.id.in_([record.id for record in stored_file_records.values()])
|
||||
).delete(synchronize_session='fetch')
|
||||
|
||||
dir_record.last_indexed_at = datetime.datetime.now()
|
||||
self.logger.info('Scanned {} in {} seconds'.format(
|
||||
media_dir, int(time.time() - index_start_time)))
|
||||
self.logger.info(
|
||||
'Scanned {} in {} seconds'.format(
|
||||
media_dir, int(time.time() - index_start_time)
|
||||
)
|
||||
)
|
||||
|
||||
session.commit()
|
||||
|
||||
|
@ -197,25 +222,30 @@ class LocalMediaSearcher(MediaSearcher):
|
|||
dir_record = self._get_or_create_dir_entry(session, media_dir)
|
||||
|
||||
if self._has_directory_changed_since_last_indexing(dir_record):
|
||||
self.logger.info('{} has changed since last indexing, '.format(
|
||||
media_dir) + 're-indexing')
|
||||
self.logger.info(
|
||||
'{} has changed since last indexing, '.format(media_dir)
|
||||
+ 're-indexing'
|
||||
)
|
||||
|
||||
self.scan(media_dir, session=session, dir_record=dir_record)
|
||||
|
||||
query_tokens = [_.lower() for _ in re.split(
|
||||
self._filename_separators, query.strip())]
|
||||
query_tokens = [
|
||||
_.lower() for _ in re.split(self._filename_separators, query.strip())
|
||||
]
|
||||
|
||||
for file_record in session.query(MediaFile.path). \
|
||||
join(MediaFileToken). \
|
||||
join(MediaToken). \
|
||||
filter(MediaToken.token.in_(query_tokens)). \
|
||||
group_by(MediaFile.path). \
|
||||
having(func.count(MediaFileToken.token_id) >= len(query_tokens)):
|
||||
for file_record in (
|
||||
session.query(MediaFile.path)
|
||||
.join(MediaFileToken)
|
||||
.join(MediaToken)
|
||||
.filter(MediaToken.token.in_(query_tokens))
|
||||
.group_by(MediaFile.path)
|
||||
.having(func.count(MediaFileToken.token_id) >= len(query_tokens))
|
||||
):
|
||||
if os.path.isfile(file_record.path):
|
||||
results[file_record.path] = {
|
||||
'url': 'file://' + file_record.path,
|
||||
'title': os.path.basename(file_record.path),
|
||||
'size': os.path.getsize(file_record.path)
|
||||
'size': os.path.getsize(file_record.path),
|
||||
}
|
||||
|
||||
return results.values()
|
||||
|
@ -223,11 +253,12 @@ class LocalMediaSearcher(MediaSearcher):
|
|||
|
||||
# --- Table definitions
|
||||
|
||||
|
||||
class MediaDirectory(Base):
|
||||
""" Models the MediaDirectory table """
|
||||
"""Models the MediaDirectory table"""
|
||||
|
||||
__tablename__ = 'MediaDirectory'
|
||||
__table_args__ = ({'sqlite_autoincrement': True})
|
||||
__table_args__ = {'sqlite_autoincrement': True}
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
path = Column(String)
|
||||
|
@ -243,14 +274,15 @@ class MediaDirectory(Base):
|
|||
|
||||
|
||||
class MediaFile(Base):
|
||||
""" Models the MediaFile table """
|
||||
"""Models the MediaFile table"""
|
||||
|
||||
__tablename__ = 'MediaFile'
|
||||
__table_args__ = ({'sqlite_autoincrement': True})
|
||||
__table_args__ = {'sqlite_autoincrement': True}
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
directory_id = Column(Integer, ForeignKey(
|
||||
'MediaDirectory.id', ondelete='CASCADE'), nullable=False)
|
||||
directory_id = Column(
|
||||
Integer, ForeignKey('MediaDirectory.id', ondelete='CASCADE'), nullable=False
|
||||
)
|
||||
path = Column(String, nullable=False, unique=True)
|
||||
indexed_at = Column(DateTime)
|
||||
|
||||
|
@ -265,10 +297,10 @@ class MediaFile(Base):
|
|||
|
||||
|
||||
class MediaToken(Base):
|
||||
""" Models the MediaToken table """
|
||||
"""Models the MediaToken table"""
|
||||
|
||||
__tablename__ = 'MediaToken'
|
||||
__table_args__ = ({'sqlite_autoincrement': True})
|
||||
__table_args__ = {'sqlite_autoincrement': True}
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
token = Column(String, nullable=False, unique=True)
|
||||
|
@ -282,14 +314,16 @@ class MediaToken(Base):
|
|||
|
||||
|
||||
class MediaFileToken(Base):
|
||||
""" Models the MediaFileToken table """
|
||||
"""Models the MediaFileToken table"""
|
||||
|
||||
__tablename__ = 'MediaFileToken'
|
||||
|
||||
file_id = Column(Integer, ForeignKey('MediaFile.id', ondelete='CASCADE'),
|
||||
nullable=False)
|
||||
token_id = Column(Integer, ForeignKey('MediaToken.id', ondelete='CASCADE'),
|
||||
nullable=False)
|
||||
file_id = Column(
|
||||
Integer, ForeignKey('MediaFile.id', ondelete='CASCADE'), nullable=False
|
||||
)
|
||||
token_id = Column(
|
||||
Integer, ForeignKey('MediaToken.id', ondelete='CASCADE'), nullable=False
|
||||
)
|
||||
|
||||
__table_args__ = (PrimaryKeyConstraint(file_id, token_id), {})
|
||||
|
||||
|
@ -301,4 +335,5 @@ class MediaFileToken(Base):
|
|||
record.token_id = token_id
|
||||
return record
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
||||
|
|
|
@ -13,11 +13,13 @@ except ImportError:
|
|||
from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import make_transient
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import make_transient, declarative_base
|
||||
|
||||
from platypush.context import get_plugin
|
||||
from platypush.exceptions.user import InvalidJWTTokenException, InvalidCredentialsException
|
||||
from platypush.exceptions.user import (
|
||||
InvalidJWTTokenException,
|
||||
InvalidCredentialsException,
|
||||
)
|
||||
from platypush.utils import get_or_generate_jwt_rsa_key_pair
|
||||
|
||||
Base = declarative_base()
|
||||
|
@ -68,8 +70,12 @@ class UserManager:
|
|||
if user:
|
||||
raise NameError('The user {} already exists'.format(username))
|
||||
|
||||
record = User(username=username, password=self._encrypt_password(password),
|
||||
created_at=datetime.datetime.utcnow(), **kwargs)
|
||||
record = User(
|
||||
username=username,
|
||||
password=self._encrypt_password(password),
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
session.add(record)
|
||||
session.commit()
|
||||
|
@ -93,10 +99,16 @@ class UserManager:
|
|||
|
||||
def authenticate_user_session(self, session_token):
|
||||
with self.db.get_session() as session:
|
||||
user_session = session.query(UserSession).filter_by(session_token=session_token).first()
|
||||
user_session = (
|
||||
session.query(UserSession)
|
||||
.filter_by(session_token=session_token)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not user_session or (
|
||||
user_session.expires_at and user_session.expires_at < datetime.datetime.utcnow()):
|
||||
user_session.expires_at
|
||||
and user_session.expires_at < datetime.datetime.utcnow()
|
||||
):
|
||||
return None, None
|
||||
|
||||
user = session.query(User).filter_by(user_id=user_session.user_id).first()
|
||||
|
@ -108,7 +120,9 @@ class UserManager:
|
|||
if not user:
|
||||
raise NameError('No such user: {}'.format(username))
|
||||
|
||||
user_sessions = session.query(UserSession).filter_by(user_id=user.user_id).all()
|
||||
user_sessions = (
|
||||
session.query(UserSession).filter_by(user_id=user.user_id).all()
|
||||
)
|
||||
for user_session in user_sessions:
|
||||
session.delete(user_session)
|
||||
|
||||
|
@ -118,7 +132,11 @@ class UserManager:
|
|||
|
||||
def delete_user_session(self, session_token):
|
||||
with self.db.get_session() as session:
|
||||
user_session = session.query(UserSession).filter_by(session_token=session_token).first()
|
||||
user_session = (
|
||||
session.query(UserSession)
|
||||
.filter_by(session_token=session_token)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not user_session:
|
||||
return False
|
||||
|
@ -134,14 +152,18 @@ class UserManager:
|
|||
return None
|
||||
|
||||
if expires_at:
|
||||
if isinstance(expires_at, int) or isinstance(expires_at, float):
|
||||
if isinstance(expires_at, (int, float)):
|
||||
expires_at = datetime.datetime.fromtimestamp(expires_at)
|
||||
elif isinstance(expires_at, str):
|
||||
expires_at = datetime.datetime.fromisoformat(expires_at)
|
||||
|
||||
user_session = UserSession(user_id=user.user_id, session_token=self.generate_session_token(),
|
||||
csrf_token=self.generate_session_token(), created_at=datetime.datetime.utcnow(),
|
||||
expires_at=expires_at)
|
||||
user_session = UserSession(
|
||||
user_id=user.user_id,
|
||||
session_token=self.generate_session_token(),
|
||||
csrf_token=self.generate_session_token(),
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
session.add(user_session)
|
||||
session.commit()
|
||||
|
@ -179,9 +201,19 @@ class UserManager:
|
|||
:param session_token: Session token.
|
||||
"""
|
||||
with self.db.get_session() as session:
|
||||
return session.query(User).join(UserSession).filter_by(session_token=session_token).first()
|
||||
return (
|
||||
session.query(User)
|
||||
.join(UserSession)
|
||||
.filter_by(session_token=session_token)
|
||||
.first()
|
||||
)
|
||||
|
||||
def generate_jwt_token(self, username: str, password: str, expires_at: Optional[datetime.datetime] = None) -> str:
|
||||
def generate_jwt_token(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
expires_at: Optional[datetime.datetime] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a user JWT token for API usage.
|
||||
|
||||
|
@ -253,10 +285,10 @@ class UserManager:
|
|||
|
||||
|
||||
class User(Base):
|
||||
""" Models the User table """
|
||||
"""Models the User table"""
|
||||
|
||||
__tablename__ = 'user'
|
||||
__table_args__ = ({'sqlite_autoincrement': True})
|
||||
__table_args__ = {'sqlite_autoincrement': True}
|
||||
|
||||
user_id = Column(Integer, primary_key=True)
|
||||
username = Column(String, unique=True, nullable=False)
|
||||
|
@ -265,10 +297,10 @@ class User(Base):
|
|||
|
||||
|
||||
class UserSession(Base):
|
||||
""" Models the UserSession table """
|
||||
"""Models the UserSession table"""
|
||||
|
||||
__tablename__ = 'user_session'
|
||||
__table_args__ = ({'sqlite_autoincrement': True})
|
||||
__table_args__ = {'sqlite_autoincrement': True}
|
||||
|
||||
session_id = Column(Integer, primary_key=True)
|
||||
session_token = Column(String, unique=True, nullable=False)
|
||||
|
|
Loading…
Reference in a new issue