diff --git a/platypush/backend/covid19/__init__.py b/platypush/backend/covid19/__init__.py index 598f10871..1be1db016 100644 --- a/platypush/backend/covid19/__init__.py +++ b/platypush/backend/covid19/__init__.py @@ -3,8 +3,7 @@ import os from typing import Optional, Union, List, Dict, Any from sqlalchemy import create_engine, Column, Integer, String, DateTime -from sqlalchemy.orm import sessionmaker, scoped_session -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base from platypush.backend import Backend from platypush.config import Config @@ -17,10 +16,10 @@ Session = scoped_session(sessionmaker()) class Covid19Update(Base): - """ Models the Covid19Data table """ + """Models the Covid19Data table""" __tablename__ = 'covid19data' - __table_args__ = ({'sqlite_autoincrement': True}) + __table_args__ = {'sqlite_autoincrement': True} country = Column(String, primary_key=True) confirmed = Column(Integer, nullable=False, default=0) @@ -40,7 +39,12 @@ class Covid19Backend(Backend): """ # noinspection PyProtectedMember - def __init__(self, country: Optional[Union[str, List[str]]], poll_seconds: Optional[float] = 3600.0, **kwargs): + def __init__( + self, + country: Optional[Union[str, List[str]]], + poll_seconds: Optional[float] = 3600.0, + **kwargs + ): """ :param country: Default country (or list of countries) to retrieve the stats for. It can either be the full country name or the country code. Special values: @@ -56,7 +60,9 @@ class Covid19Backend(Backend): super().__init__(poll_seconds=poll_seconds, **kwargs) self._plugin: Covid19Plugin = get_plugin('covid19') self.country: List[str] = self._plugin._get_countries(country) - self.workdir = os.path.join(os.path.expanduser(Config.get('workdir')), 'covid19') + self.workdir = os.path.join( + os.path.expanduser(Config.get('workdir')), 'covid19' + ) self.dbfile = os.path.join(self.workdir, 'data.db') os.makedirs(self.workdir, exist_ok=True) @@ -67,22 +73,30 @@ class Covid19Backend(Backend): self.logger.info('Stopped Covid19 backend') def _process_update(self, summary: Dict[str, Any], session: Session): - update_time = datetime.datetime.fromisoformat(summary['Date'].replace('Z', '+00:00')) + update_time = datetime.datetime.fromisoformat( + summary['Date'].replace('Z', '+00:00') + ) - self.bus.post(Covid19UpdateEvent( - country=summary['Country'], - country_code=summary['CountryCode'], - confirmed=summary['TotalConfirmed'], - deaths=summary['TotalDeaths'], - recovered=summary['TotalRecovered'], - update_time=update_time, - )) + self.bus.post( + Covid19UpdateEvent( + country=summary['Country'], + country_code=summary['CountryCode'], + confirmed=summary['TotalConfirmed'], + deaths=summary['TotalDeaths'], + recovered=summary['TotalRecovered'], + update_time=update_time, + ) + ) - session.merge(Covid19Update(country=summary['CountryCode'], - confirmed=summary['TotalConfirmed'], - deaths=summary['TotalDeaths'], - recovered=summary['TotalRecovered'], - last_updated_at=update_time)) + session.merge( + Covid19Update( + country=summary['CountryCode'], + confirmed=summary['TotalConfirmed'], + deaths=summary['TotalDeaths'], + recovered=summary['TotalRecovered'], + last_updated_at=update_time, + ) + ) def loop(self): # noinspection PyUnresolvedReferences @@ -90,23 +104,30 @@ class Covid19Backend(Backend): if not summaries: return - engine = create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False}) + engine = create_engine( + 'sqlite:///{}'.format(self.dbfile), + connect_args={'check_same_thread': False}, + ) Base.metadata.create_all(engine) Session.configure(bind=engine) session = Session() last_records = { record.country: record - for record in session.query(Covid19Update).filter(Covid19Update.country.in_(self.country)).all() + for record in session.query(Covid19Update) + .filter(Covid19Update.country.in_(self.country)) + .all() } for summary in summaries: country = summary['CountryCode'] last_record = last_records.get(country) - if not last_record or \ - summary['TotalConfirmed'] != last_record.confirmed or \ - summary['TotalDeaths'] != last_record.deaths or \ - summary['TotalRecovered'] != last_record.recovered: + if ( + not last_record + or summary['TotalConfirmed'] != last_record.confirmed + or summary['TotalDeaths'] != last_record.deaths + or summary['TotalRecovered'] != last_record.recovered + ): self._process_update(summary=summary, session=session) session.commit() diff --git a/platypush/backend/github/__init__.py b/platypush/backend/github/__init__.py index ad49b73d4..0a1bc3e67 100644 --- a/platypush/backend/github/__init__.py +++ b/platypush/backend/github/__init__.py @@ -6,15 +6,28 @@ from typing import Optional, List import requests from sqlalchemy import create_engine, Column, String, DateTime -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker, scoped_session +from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base from platypush.backend import Backend from platypush.config import Config -from platypush.message.event.github import GithubPushEvent, GithubCommitCommentEvent, GithubCreateEvent, \ - GithubDeleteEvent, GithubEvent, GithubForkEvent, GithubWikiEvent, GithubIssueCommentEvent, GithubIssueEvent, \ - GithubMemberEvent, GithubPublicEvent, GithubPullRequestEvent, GithubPullRequestReviewCommentEvent, \ - GithubReleaseEvent, GithubSponsorshipEvent, GithubWatchEvent +from platypush.message.event.github import ( + GithubPushEvent, + GithubCommitCommentEvent, + GithubCreateEvent, + GithubDeleteEvent, + GithubEvent, + GithubForkEvent, + GithubWikiEvent, + GithubIssueCommentEvent, + GithubIssueEvent, + GithubMemberEvent, + GithubPublicEvent, + GithubPullRequestEvent, + GithubPullRequestReviewCommentEvent, + GithubReleaseEvent, + GithubSponsorshipEvent, + GithubWatchEvent, +) Base = declarative_base() Session = scoped_session(sessionmaker()) @@ -71,8 +84,17 @@ class GithubBackend(Backend): _base_url = 'https://api.github.com' - def __init__(self, user: str, user_token: str, repos: Optional[List[str]] = None, org: Optional[str] = None, - poll_seconds: int = 60, max_events_per_scan: Optional[int] = 10, *args, **kwargs): + def __init__( + self, + user: str, + user_token: str, + repos: Optional[List[str]] = None, + org: Optional[str] = None, + poll_seconds: int = 60, + max_events_per_scan: Optional[int] = 10, + *args, + **kwargs + ): """ If neither ``repos`` nor ``org`` is specified then the backend will monitor all new events on user level. @@ -102,17 +124,23 @@ class GithubBackend(Backend): def _request(self, uri: str, method: str = 'get') -> dict: method = getattr(requests, method.lower()) - return method(self._base_url + uri, auth=(self.user, self.user_token), - headers={'Accept': 'application/vnd.github.v3+json'}).json() + return method( + self._base_url + uri, + auth=(self.user, self.user_token), + headers={'Accept': 'application/vnd.github.v3+json'}, + ).json() def _init_db(self): - engine = create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False}) + engine = create_engine( + 'sqlite:///{}'.format(self.dbfile), + connect_args={'check_same_thread': False}, + ) Base.metadata.create_all(engine) Session.configure(bind=engine) @staticmethod def _to_datetime(time_string: str) -> datetime.datetime: - """ Convert ISO 8061 string format with leading 'Z' into something understandable by Python """ + """Convert ISO 8061 string format with leading 'Z' into something understandable by Python""" return datetime.datetime.fromisoformat(time_string[:-1] + '+00:00') @staticmethod @@ -128,7 +156,11 @@ class GithubBackend(Backend): def _get_last_event_time(self, uri: str): with self.db_lock: record = self._get_or_create_resource(uri=uri, session=Session()) - return record.last_updated_at.replace(tzinfo=datetime.timezone.utc) if record.last_updated_at else None + return ( + record.last_updated_at.replace(tzinfo=datetime.timezone.utc) + if record.last_updated_at + else None + ) def _update_last_event_time(self, uri: str, last_updated_at: datetime.datetime): with self.db_lock: @@ -158,9 +190,18 @@ class GithubBackend(Backend): 'WatchEvent': GithubWatchEvent, } - event_type = event_mapping[event['type']] if event['type'] in event_mapping else GithubEvent - return event_type(event_type=event['type'], actor=event['actor'], repo=event.get('repo', {}), - payload=event['payload'], created_at=cls._to_datetime(event['created_at'])) + event_type = ( + event_mapping[event['type']] + if event['type'] in event_mapping + else GithubEvent + ) + return event_type( + event_type=event['type'], + actor=event['actor'], + repo=event.get('repo', {}), + payload=event['payload'], + created_at=cls._to_datetime(event['created_at']), + ) def _events_monitor(self, uri: str, method: str = 'get'): def thread(): @@ -175,7 +216,10 @@ class GithubBackend(Backend): fired_events = [] for event in events: - if self.max_events_per_scan and len(fired_events) >= self.max_events_per_scan: + if ( + self.max_events_per_scan + and len(fired_events) >= self.max_events_per_scan + ): break event_time = self._to_datetime(event['created_at']) @@ -189,14 +233,19 @@ class GithubBackend(Backend): for event in fired_events: self.bus.post(event) - self._update_last_event_time(uri=uri, last_updated_at=new_last_event_time) + self._update_last_event_time( + uri=uri, last_updated_at=new_last_event_time + ) except Exception as e: - self.logger.warning('Encountered exception while fetching events from {}: {}'.format( - uri, str(e))) + self.logger.warning( + 'Encountered exception while fetching events from {}: {}'.format( + uri, str(e) + ) + ) self.logger.exception(e) - finally: - if self.wait_stop(timeout=self.poll_seconds): - break + + if self.wait_stop(timeout=self.poll_seconds): + break return thread @@ -206,12 +255,30 @@ class GithubBackend(Backend): if self.repos: for repo in self.repos: - monitors.append(threading.Thread(target=self._events_monitor('/networks/{repo}/events'.format(repo=repo)))) + monitors.append( + threading.Thread( + target=self._events_monitor( + '/networks/{repo}/events'.format(repo=repo) + ) + ) + ) if self.org: - monitors.append(threading.Thread(target=self._events_monitor('/orgs/{org}/events'.format(org=self.org)))) + monitors.append( + threading.Thread( + target=self._events_monitor( + '/orgs/{org}/events'.format(org=self.org) + ) + ) + ) if not (self.repos or self.org): - monitors.append(threading.Thread(target=self._events_monitor('/users/{user}/events'.format(user=self.user)))) + monitors.append( + threading.Thread( + target=self._events_monitor( + '/users/{user}/events'.format(user=self.user) + ) + ) + ) for monitor in monitors: monitor.start() @@ -222,4 +289,5 @@ class GithubBackend(Backend): self.logger.info('Github backend terminated') + # vim:sw=4:ts=4:et: diff --git a/platypush/backend/http/request/rss/__init__.py b/platypush/backend/http/request/rss/__init__.py index b16565dc5..7ca6d9c64 100644 --- a/platypush/backend/http/request/rss/__init__.py +++ b/platypush/backend/http/request/rss/__init__.py @@ -2,11 +2,17 @@ import datetime import enum import os -from sqlalchemy import create_engine, Column, Integer, String, DateTime, \ - Enum, ForeignKey +from sqlalchemy import ( + create_engine, + Column, + Integer, + String, + DateTime, + Enum, + ForeignKey, +) -from sqlalchemy.orm import sessionmaker, scoped_session -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base from sqlalchemy.sql.expression import func from platypush.backend.http.request import HttpRequest @@ -44,18 +50,31 @@ class RssUpdates(HttpRequest): """ - user_agent = 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) ' + \ - 'Chrome/62.0.3202.94 Safari/537.36' + user_agent = ( + 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) ' + + 'Chrome/62.0.3202.94 Safari/537.36' + ) - def __init__(self, url, title=None, headers=None, params=None, max_entries=None, - extract_content=False, digest_format=None, user_agent: str = user_agent, - body_style: str = 'font-size: 22px; ' + - 'font-family: "Merriweather", Georgia, "Times New Roman", Times, serif;', - title_style: str = 'margin-top: 30px', - subtitle_style: str = 'margin-top: 10px; page-break-after: always', - article_title_style: str = 'page-break-before: always', - article_link_style: str = 'color: #555; text-decoration: none; border-bottom: 1px dotted', - article_content_style: str = '', *argv, **kwargs): + def __init__( + self, + url, + title=None, + headers=None, + params=None, + max_entries=None, + extract_content=False, + digest_format=None, + user_agent: str = user_agent, + body_style: str = 'font-size: 22px; ' + + 'font-family: "Merriweather", Georgia, "Times New Roman", Times, serif;', + title_style: str = 'margin-top: 30px', + subtitle_style: str = 'margin-top: 10px; page-break-after: always', + article_title_style: str = 'page-break-before: always', + article_link_style: str = 'color: #555; text-decoration: none; border-bottom: 1px dotted', + article_content_style: str = '', + *argv, + **kwargs, + ): """ :param url: URL to the RSS feed to be monitored. :param title: Optional title for the feed. @@ -91,7 +110,9 @@ class RssUpdates(HttpRequest): # If true, then the http.webpage plugin will be used to parse the content self.extract_content = extract_content - self.digest_format = digest_format.lower() if digest_format else None # Supported formats: html, pdf + self.digest_format = ( + digest_format.lower() if digest_format else None + ) # Supported formats: html, pdf os.makedirs(os.path.expanduser(os.path.dirname(self.dbfile)), exist_ok=True) @@ -119,7 +140,11 @@ class RssUpdates(HttpRequest): @staticmethod def _get_latest_update(session, source_id): - return session.query(func.max(FeedEntry.published)).filter_by(source_id=source_id).scalar() + return ( + session.query(func.max(FeedEntry.published)) + .filter_by(source_id=source_id) + .scalar() + ) def _parse_entry_content(self, link): self.logger.info('Extracting content from {}'.format(link)) @@ -130,14 +155,20 @@ class RssUpdates(HttpRequest): errors = response.errors if not output: - self.logger.warning('Mercury parser error: {}'.format(errors or '[unknown error]')) + self.logger.warning( + 'Mercury parser error: {}'.format(errors or '[unknown error]') + ) return return output.get('content') def get_new_items(self, response): import feedparser - engine = create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False}) + + engine = create_engine( + 'sqlite:///{}'.format(self.dbfile), + connect_args={'check_same_thread': False}, + ) Base.metadata.create_all(engine) Session.configure(bind=engine) @@ -157,12 +188,16 @@ class RssUpdates(HttpRequest): content = u'''

{title}

-

Feeds digest generated on {creation_date}

'''.\ - format(title_style=self.title_style, title=self.title, subtitle_style=self.subtitle_style, - creation_date=datetime.datetime.now().strftime('%d %B %Y, %H:%M')) +

Feeds digest generated on {creation_date}

'''.format( + title_style=self.title_style, + title=self.title, + subtitle_style=self.subtitle_style, + creation_date=datetime.datetime.now().strftime('%d %B %Y, %H:%M'), + ) - self.logger.info('Parsed {:d} items from RSS feed <{}>' - .format(len(feed.entries), self.url)) + self.logger.info( + 'Parsed {:d} items from RSS feed <{}>'.format(len(feed.entries), self.url) + ) for entry in feed.entries: if not entry.published_parsed: @@ -171,9 +206,10 @@ class RssUpdates(HttpRequest): try: entry_timestamp = datetime.datetime(*entry.published_parsed[:6]) - if latest_update is None \ - or entry_timestamp > latest_update: - self.logger.info('Processed new item from RSS feed <{}>'.format(self.url)) + if latest_update is None or entry_timestamp > latest_update: + self.logger.info( + 'Processed new item from RSS feed <{}>'.format(self.url) + ) entry.summary = entry.summary if hasattr(entry, 'summary') else None if self.extract_content: @@ -188,9 +224,13 @@ class RssUpdates(HttpRequest): {title}
{content}
'''.format( - article_title_style=self.article_title_style, article_link_style=self.article_link_style, - article_content_style=self.article_content_style, link=entry.link, title=entry.title, - content=entry.content) + article_title_style=self.article_title_style, + article_link_style=self.article_link_style, + article_content_style=self.article_content_style, + link=entry.link, + title=entry.title, + content=entry.content, + ) e = { 'entry_id': entry.id, @@ -207,21 +247,32 @@ class RssUpdates(HttpRequest): if self.max_entries and len(entries) > self.max_entries: break except Exception as e: - self.logger.warning('Exception encountered while parsing RSS ' + - 'RSS feed {}: {}'.format(entry.link, str(e))) + self.logger.warning( + 'Exception encountered while parsing RSS ' + + f'RSS feed {entry.link}: {e}' + ) self.logger.exception(e) source_record.last_updated_at = parse_start_time digest_filename = None if entries: - self.logger.info('Parsed {} new entries from the RSS feed {}'.format( - len(entries), self.title)) + self.logger.info( + 'Parsed {} new entries from the RSS feed {}'.format( + len(entries), self.title + ) + ) if self.digest_format: - digest_filename = os.path.join(self.workdir, 'cache', '{}_{}.{}'.format( - datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'), - self.title, self.digest_format)) + digest_filename = os.path.join( + self.workdir, + 'cache', + '{}_{}.{}'.format( + datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'), + self.title, + self.digest_format, + ), + ) os.makedirs(os.path.dirname(digest_filename), exist_ok=True) @@ -233,12 +284,15 @@ class RssUpdates(HttpRequest): {content} - '''.format(title=self.title, body_style=self.body_style, content=content) + '''.format( + title=self.title, body_style=self.body_style, content=content + ) with open(digest_filename, 'w', encoding='utf-8') as f: f.write(content) elif self.digest_format == 'pdf': from weasyprint import HTML, CSS + try: from weasyprint.fonts import FontConfiguration except ImportError: @@ -246,37 +300,47 @@ class RssUpdates(HttpRequest): body_style = 'body { ' + self.body_style + ' }' font_config = FontConfiguration() - css = [CSS('https://fonts.googleapis.com/css?family=Merriweather'), - CSS(string=body_style, font_config=font_config)] + css = [ + CSS('https://fonts.googleapis.com/css?family=Merriweather'), + CSS(string=body_style, font_config=font_config), + ] HTML(string=content).write_pdf(digest_filename, stylesheets=css) else: - raise RuntimeError('Unsupported format: {}. Supported formats: ' + - 'html or pdf'.format(self.digest_format)) + raise RuntimeError( + f'Unsupported format: {self.digest_format}. Supported formats: html, pdf' + ) - digest_entry = FeedDigest(source_id=source_record.id, - format=self.digest_format, - filename=digest_filename) + digest_entry = FeedDigest( + source_id=source_record.id, + format=self.digest_format, + filename=digest_filename, + ) session.add(digest_entry) - self.logger.info('{} digest ready: {}'.format(self.digest_format, digest_filename)) + self.logger.info( + '{} digest ready: {}'.format(self.digest_format, digest_filename) + ) session.commit() self.logger.info('Parsing RSS feed {}: completed'.format(self.title)) - return NewFeedEvent(request=dict(self), response=entries, - source_id=source_record.id, - source_title=source_record.title, - title=self.title, - digest_format=self.digest_format, - digest_filename=digest_filename) + return NewFeedEvent( + request=dict(self), + response=entries, + source_id=source_record.id, + source_title=source_record.title, + title=self.title, + digest_format=self.digest_format, + digest_filename=digest_filename, + ) class FeedSource(Base): - """ Models the FeedSource table, containing RSS sources to be parsed """ + """Models the FeedSource table, containing RSS sources to be parsed""" __tablename__ = 'FeedSource' - __table_args__ = ({'sqlite_autoincrement': True}) + __table_args__ = {'sqlite_autoincrement': True} id = Column(Integer, primary_key=True) title = Column(String) @@ -285,10 +349,10 @@ class FeedSource(Base): class FeedEntry(Base): - """ Models the FeedEntry table, which contains RSS entries """ + """Models the FeedEntry table, which contains RSS entries""" __tablename__ = 'FeedEntry' - __table_args__ = ({'sqlite_autoincrement': True}) + __table_args__ = {'sqlite_autoincrement': True} id = Column(Integer, primary_key=True) entry_id = Column(String) @@ -301,15 +365,15 @@ class FeedEntry(Base): class FeedDigest(Base): - """ Models the FeedDigest table, containing feed digests either in HTML - or PDF format """ + """Models the FeedDigest table, containing feed digests either in HTML + or PDF format""" class DigestFormat(enum.Enum): html = 1 pdf = 2 __tablename__ = 'FeedDigest' - __table_args__ = ({'sqlite_autoincrement': True}) + __table_args__ = {'sqlite_autoincrement': True} id = Column(Integer, primary_key=True) source_id = Column(Integer, ForeignKey('FeedSource.id'), nullable=False) @@ -317,4 +381,5 @@ class FeedDigest(Base): filename = Column(String, nullable=False) created_at = Column(DateTime, nullable=False, default=datetime.datetime.utcnow) + # vim:sw=4:ts=4:et: diff --git a/platypush/backend/mail/__init__.py b/platypush/backend/mail/__init__.py index 686075286..84f1f8ab6 100644 --- a/platypush/backend/mail/__init__.py +++ b/platypush/backend/mail/__init__.py @@ -8,15 +8,18 @@ from queue import Queue, Empty from threading import Thread, RLock from typing import List, Dict, Any, Optional, Tuple -from sqlalchemy import create_engine, Column, Integer, String, DateTime -import sqlalchemy.engine as engine -from sqlalchemy.orm import sessionmaker, scoped_session -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import engine, create_engine, Column, Integer, String, DateTime +from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base from platypush.backend import Backend from platypush.config import Config from platypush.context import get_plugin -from platypush.message.event.mail import MailReceivedEvent, MailSeenEvent, MailFlaggedEvent, MailUnflaggedEvent +from platypush.message.event.mail import ( + MailReceivedEvent, + MailSeenEvent, + MailFlaggedEvent, + MailUnflaggedEvent, +) from platypush.plugins.mail import MailInPlugin, Mail # @@ -25,7 +28,8 @@ Session = scoped_session(sessionmaker()) class MailboxStatus(Base): - """ Models the MailboxStatus table, containing information about the state of a monitored mailbox. """ + """Models the MailboxStatus table, containing information about the state of a monitored mailbox.""" + __tablename__ = 'MailboxStatus' mailbox_id = Column(Integer, primary_key=True) @@ -64,8 +68,13 @@ class MailBackend(Backend): """ - def __init__(self, mailboxes: List[Dict[str, Any]], timeout: Optional[int] = 60, poll_seconds: Optional[int] = 60, - **kwargs): + def __init__( + self, + mailboxes: List[Dict[str, Any]], + timeout: Optional[int] = 60, + poll_seconds: Optional[int] = 60, + **kwargs + ): """ :param mailboxes: List of mailboxes to be monitored. Each mailbox entry contains a ``plugin`` attribute to identify the :class:`platypush.plugins.mail.MailInPlugin` plugin that will be used (e.g. ``mail.imap``) @@ -128,9 +137,13 @@ class MailBackend(Backend): # Parse mailboxes for i, mbox in enumerate(mailboxes): - assert 'plugin' in mbox, 'No plugin attribute specified for mailbox n.{}'.format(i) + assert ( + 'plugin' in mbox + ), 'No plugin attribute specified for mailbox n.{}'.format(i) plugin = get_plugin(mbox.pop('plugin')) - assert isinstance(plugin, MailInPlugin), '{} is not a MailInPlugin'.format(plugin) + assert isinstance(plugin, MailInPlugin), '{} is not a MailInPlugin'.format( + plugin + ) name = mbox.pop('name') if 'name' in mbox else 'Mailbox #{}'.format(i + 1) self.mailboxes.append(Mailbox(plugin=plugin, name=name, args=mbox)) @@ -144,7 +157,10 @@ class MailBackend(Backend): # def _db_get_engine(self) -> engine.Engine: - return create_engine('sqlite:///{}'.format(self.dbfile), connect_args={'check_same_thread': False}) + return create_engine( + 'sqlite:///{}'.format(self.dbfile), + connect_args={'check_same_thread': False}, + ) def _db_load_mailboxes_status(self) -> None: mailbox_ids = list(range(len(self.mailboxes))) @@ -153,12 +169,18 @@ class MailBackend(Backend): session = Session() records = { record.mailbox_id: record - for record in session.query(MailboxStatus).filter(MailboxStatus.mailbox_id.in_(mailbox_ids)).all() + for record in session.query(MailboxStatus) + .filter(MailboxStatus.mailbox_id.in_(mailbox_ids)) + .all() } - for mbox_id, mbox in enumerate(self.mailboxes): + for mbox_id, _ in enumerate(self.mailboxes): if mbox_id not in records: - record = MailboxStatus(mailbox_id=mbox_id, unseen_message_ids='[]', flagged_message_ids='[]') + record = MailboxStatus( + mailbox_id=mbox_id, + unseen_message_ids='[]', + flagged_message_ids='[]', + ) session.add(record) else: record = records[mbox_id] @@ -170,19 +192,25 @@ class MailBackend(Backend): session.commit() - def _db_get_mailbox_status(self, mailbox_ids: List[int]) -> Dict[int, MailboxStatus]: + def _db_get_mailbox_status( + self, mailbox_ids: List[int] + ) -> Dict[int, MailboxStatus]: with self._db_lock: session = Session() return { record.mailbox_id: record - for record in session.query(MailboxStatus).filter(MailboxStatus.mailbox_id.in_(mailbox_ids)).all() + for record in session.query(MailboxStatus) + .filter(MailboxStatus.mailbox_id.in_(mailbox_ids)) + .all() } # # @staticmethod - def _check_thread(unread_queue: Queue, flagged_queue: Queue, plugin: MailInPlugin, **args): + def _check_thread( + unread_queue: Queue, flagged_queue: Queue, plugin: MailInPlugin, **args + ): def thread(): # noinspection PyUnresolvedReferences unread = plugin.search_unseen_messages(**args).output @@ -194,8 +222,9 @@ class MailBackend(Backend): return thread - def _get_unread_seen_msgs(self, mailbox_idx: int, unread_msgs: Dict[int, Mail]) \ - -> Tuple[Dict[int, Mail], Dict[int, Mail]]: + def _get_unread_seen_msgs( + self, mailbox_idx: int, unread_msgs: Dict[int, Mail] + ) -> Tuple[Dict[int, Mail], Dict[int, Mail]]: prev_unread_msgs = self._unread_msgs[mailbox_idx] return { @@ -208,35 +237,51 @@ class MailBackend(Backend): if msg_id not in unread_msgs } - def _get_flagged_unflagged_msgs(self, mailbox_idx: int, flagged_msgs: Dict[int, Mail]) \ - -> Tuple[Dict[int, Mail], Dict[int, Mail]]: + def _get_flagged_unflagged_msgs( + self, mailbox_idx: int, flagged_msgs: Dict[int, Mail] + ) -> Tuple[Dict[int, Mail], Dict[int, Mail]]: prev_flagged_msgs = self._flagged_msgs[mailbox_idx] return { - msg_id: flagged_msgs[msg_id] - for msg_id in flagged_msgs - if msg_id not in prev_flagged_msgs - }, { - msg_id: prev_flagged_msgs[msg_id] - for msg_id in prev_flagged_msgs - if msg_id not in flagged_msgs - } + msg_id: flagged_msgs[msg_id] + for msg_id in flagged_msgs + if msg_id not in prev_flagged_msgs + }, { + msg_id: prev_flagged_msgs[msg_id] + for msg_id in prev_flagged_msgs + if msg_id not in flagged_msgs + } - def _process_msg_events(self, mailbox_id: int, unread: List[Mail], seen: List[Mail], - flagged: List[Mail], unflagged: List[Mail], last_checked_date: Optional[datetime] = None): + def _process_msg_events( + self, + mailbox_id: int, + unread: List[Mail], + seen: List[Mail], + flagged: List[Mail], + unflagged: List[Mail], + last_checked_date: Optional[datetime] = None, + ): for msg in unread: if msg.date and last_checked_date and msg.date < last_checked_date: continue - self.bus.post(MailReceivedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)) + self.bus.post( + MailReceivedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg) + ) for msg in seen: - self.bus.post(MailSeenEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)) + self.bus.post( + MailSeenEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg) + ) for msg in flagged: - self.bus.post(MailFlaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)) + self.bus.post( + MailFlaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg) + ) for msg in unflagged: - self.bus.post(MailUnflaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg)) + self.bus.post( + MailUnflaggedEvent(mailbox=self.mailboxes[mailbox_id].name, message=msg) + ) def _check_mailboxes(self) -> List[Tuple[Dict[int, Mail], Dict[int, Mail]]]: workers = [] @@ -245,8 +290,14 @@ class MailBackend(Backend): for mbox in self.mailboxes: unread_queue, flagged_queue = [Queue()] * 2 - worker = Thread(target=self._check_thread(unread_queue=unread_queue, flagged_queue=flagged_queue, - plugin=mbox.plugin, **mbox.args)) + worker = Thread( + target=self._check_thread( + unread_queue=unread_queue, + flagged_queue=flagged_queue, + plugin=mbox.plugin, + **mbox.args + ) + ) worker.start() workers.append(worker) queues.append((unread_queue, flagged_queue)) @@ -260,7 +311,11 @@ class MailBackend(Backend): flagged = flagged_queue.get(timeout=self.timeout) results.append((unread, flagged)) except Empty: - self.logger.warning('Checks on mailbox #{} timed out after {} seconds'.format(i + 1, self.timeout)) + self.logger.warning( + 'Checks on mailbox #{} timed out after {} seconds'.format( + i + 1, self.timeout + ) + ) continue return results @@ -276,16 +331,25 @@ class MailBackend(Backend): for i, (unread, flagged) in enumerate(results): unread_msgs, seen_msgs = self._get_unread_seen_msgs(i, unread) flagged_msgs, unflagged_msgs = self._get_flagged_unflagged_msgs(i, flagged) - self._process_msg_events(i, unread=list(unread_msgs.values()), seen=list(seen_msgs.values()), - flagged=list(flagged_msgs.values()), unflagged=list(unflagged_msgs.values()), - last_checked_date=mailbox_statuses[i].last_checked_date) + self._process_msg_events( + i, + unread=list(unread_msgs.values()), + seen=list(seen_msgs.values()), + flagged=list(flagged_msgs.values()), + unflagged=list(unflagged_msgs.values()), + last_checked_date=mailbox_statuses[i].last_checked_date, + ) self._unread_msgs[i] = unread self._flagged_msgs[i] = flagged - records.append(MailboxStatus(mailbox_id=i, - unseen_message_ids=json.dumps([msg_id for msg_id in unread.keys()]), - flagged_message_ids=json.dumps([msg_id for msg_id in flagged.keys()]), - last_checked_date=datetime.now())) + records.append( + MailboxStatus( + mailbox_id=i, + unseen_message_ids=json.dumps(list(unread.keys())), + flagged_message_ids=json.dumps(list(flagged.keys())), + last_checked_date=datetime.now(), + ) + ) with self._db_lock: session = Session() diff --git a/platypush/entities/_base.py b/platypush/entities/_base.py index 80be23b83..fb38fa460 100644 --- a/platypush/entities/_base.py +++ b/platypush/entities/_base.py @@ -5,7 +5,7 @@ from typing import Mapping, Type import pkgutil from sqlalchemy import Column, Index, Integer, String, DateTime, JSON, UniqueConstraint -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import declarative_base Base = declarative_base() entities_registry: Mapping[Type['Entity'], Mapping] = {} @@ -24,14 +24,16 @@ class Entity(Base): type = Column(String, nullable=False, index=True) plugin = Column(String, nullable=False) data = Column(JSON, default=dict) - created_at = Column(DateTime(timezone=False), default=datetime.utcnow(), nullable=False) - updated_at = Column(DateTime(timezone=False), default=datetime.utcnow(), onupdate=datetime.now()) + created_at = Column( + DateTime(timezone=False), default=datetime.utcnow(), nullable=False + ) + updated_at = Column( + DateTime(timezone=False), default=datetime.utcnow(), onupdate=datetime.now() + ) UniqueConstraint(external_id, plugin) - __table_args__ = ( - Index(name, plugin), - ) + __table_args__ = (Index(name, plugin),) __mapper_args__ = { 'polymorphic_identity': __tablename__, @@ -41,13 +43,14 @@ class Entity(Base): def _discover_entity_types(): from platypush.context import get_plugin + logger = get_plugin('logger') assert logger for loader, modname, _ in pkgutil.walk_packages( path=[str(pathlib.Path(__file__).parent.absolute())], prefix=__package__ + '.', - onerror=lambda _: None + onerror=lambda _: None, ): try: mod_loader = loader.find_module(modname) # type: ignore @@ -65,9 +68,9 @@ def _discover_entity_types(): def init_entities_db(): from platypush.context import get_plugin + _discover_entity_types() db = get_plugin('db') assert db engine = db.get_engine() db.create_all(engine, Base) - diff --git a/platypush/plugins/media/search/local.py b/platypush/plugins/media/search/local.py index 3298dfe17..76868aa4f 100644 --- a/platypush/plugins/media/search/local.py +++ b/platypush/plugins/media/search/local.py @@ -3,9 +3,16 @@ import os import re import time -from sqlalchemy import create_engine, Column, Integer, String, DateTime, PrimaryKeyConstraint, ForeignKey -from sqlalchemy.orm import sessionmaker, scoped_session -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import ( + create_engine, + Column, + Integer, + String, + DateTime, + PrimaryKeyConstraint, + ForeignKey, +) +from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base from sqlalchemy.sql.expression import func from platypush.config import Config @@ -38,7 +45,8 @@ class LocalMediaSearcher(MediaSearcher): if not self._db_engine: self._db_engine = create_engine( 'sqlite:///{}'.format(self.db_file), - connect_args={'check_same_thread': False}) + connect_args={'check_same_thread': False}, + ) Base.metadata.create_all(self._db_engine) Session.configure(bind=self._db_engine) @@ -57,27 +65,30 @@ class LocalMediaSearcher(MediaSearcher): @classmethod def _get_last_modify_time(cls, path, recursive=False): - return max([os.path.getmtime(p) for p, _, _ in os.walk(path)]) \ - if recursive else os.path.getmtime(path) + return ( + max([os.path.getmtime(p) for p, _, _ in os.walk(path)]) + if recursive + else os.path.getmtime(path) + ) @classmethod - def _has_directory_changed_since_last_indexing(self, dir_record): + def _has_directory_changed_since_last_indexing(cls, dir_record): if not dir_record.last_indexed_at: return True - return datetime.datetime.fromtimestamp( - self._get_last_modify_time(dir_record.path)) > dir_record.last_indexed_at + return ( + datetime.datetime.fromtimestamp(cls._get_last_modify_time(dir_record.path)) + > dir_record.last_indexed_at + ) @classmethod def _matches_query(cls, filename, query): filename = filename.lower() - query_tokens = [_.lower() for _ in re.split( - cls._filename_separators, query.strip())] + query_tokens = [ + _.lower() for _ in re.split(cls._filename_separators, query.strip()) + ] - for token in query_tokens: - if token not in filename: - return False - return True + return all(token in filename for token in query_tokens) @classmethod def _sync_token_records(cls, session, *tokens): @@ -85,9 +96,12 @@ class LocalMediaSearcher(MediaSearcher): if not tokens: return [] - records = {record.token: record for record in - session.query(MediaToken).filter( - MediaToken.token.in_(tokens)).all()} + records = { + record.token: record + for record in session.query(MediaToken) + .filter(MediaToken.token.in_(tokens)) + .all() + } for token in tokens: if token in records: @@ -97,13 +111,11 @@ class LocalMediaSearcher(MediaSearcher): records[token] = record session.commit() - return session.query(MediaToken).filter( - MediaToken.token.in_(tokens)).all() + return session.query(MediaToken).filter(MediaToken.token.in_(tokens)).all() @classmethod def _get_file_records(cls, dir_record, session): - return session.query(MediaFile).filter_by( - directory_id=dir_record.id).all() + return session.query(MediaFile).filter_by(directory_id=dir_record.id).all() def scan(self, media_dir, session=None, dir_record=None): """ @@ -121,17 +133,19 @@ class LocalMediaSearcher(MediaSearcher): dir_record = self._get_or_create_dir_entry(session, media_dir) if not os.path.isdir(media_dir): - self.logger.info('Directory {} is no longer accessible, removing it'. - format(media_dir)) - session.query(MediaDirectory) \ - .filter(MediaDirectory.path == media_dir) \ - .delete(synchronize_session='fetch') + self.logger.info( + 'Directory {} is no longer accessible, removing it'.format(media_dir) + ) + session.query(MediaDirectory).filter( + MediaDirectory.path == media_dir + ).delete(synchronize_session='fetch') return stored_file_records = { - f.path: f for f in self._get_file_records(dir_record, session)} + f.path: f for f in self._get_file_records(dir_record, session) + } - for path, dirs, files in os.walk(media_dir): + for path, _, files in os.walk(media_dir): for filename in files: filepath = os.path.join(path, filename) @@ -142,26 +156,32 @@ class LocalMediaSearcher(MediaSearcher): del stored_file_records[filepath] continue - if not MediaPlugin.is_video_file(filename) and \ - not MediaPlugin.is_audio_file(filename): + if not MediaPlugin.is_video_file( + filename + ) and not MediaPlugin.is_audio_file(filename): continue self.logger.debug('Syncing item {}'.format(filepath)) - tokens = [_.lower() for _ in re.split(self._filename_separators, - filename.strip())] + tokens = [ + _.lower() + for _ in re.split(self._filename_separators, filename.strip()) + ] token_records = self._sync_token_records(session, *tokens) - file_record = MediaFile.build(directory_id=dir_record.id, - path=filepath) + file_record = MediaFile.build(directory_id=dir_record.id, path=filepath) session.add(file_record) session.commit() - file_record = session.query(MediaFile).filter_by( - directory_id=dir_record.id, path=filepath).one() + file_record = ( + session.query(MediaFile) + .filter_by(directory_id=dir_record.id, path=filepath) + .one() + ) for token_record in token_records: - file_token = MediaFileToken.build(file_id=file_record.id, - token_id=token_record.id) + file_token = MediaFileToken.build( + file_id=file_record.id, token_id=token_record.id + ) session.add(file_token) # stored_file_records should now only contain the records of the files @@ -169,15 +189,20 @@ class LocalMediaSearcher(MediaSearcher): if stored_file_records: self.logger.info( 'Removing references to {} deleted media items from {}'.format( - len(stored_file_records), media_dir)) + len(stored_file_records), media_dir + ) + ) - session.query(MediaFile).filter(MediaFile.id.in_( - [record.id for record in stored_file_records.values()] - )).delete(synchronize_session='fetch') + session.query(MediaFile).filter( + MediaFile.id.in_([record.id for record in stored_file_records.values()]) + ).delete(synchronize_session='fetch') dir_record.last_indexed_at = datetime.datetime.now() - self.logger.info('Scanned {} in {} seconds'.format( - media_dir, int(time.time() - index_start_time))) + self.logger.info( + 'Scanned {} in {} seconds'.format( + media_dir, int(time.time() - index_start_time) + ) + ) session.commit() @@ -197,25 +222,30 @@ class LocalMediaSearcher(MediaSearcher): dir_record = self._get_or_create_dir_entry(session, media_dir) if self._has_directory_changed_since_last_indexing(dir_record): - self.logger.info('{} has changed since last indexing, '.format( - media_dir) + 're-indexing') + self.logger.info( + '{} has changed since last indexing, '.format(media_dir) + + 're-indexing' + ) self.scan(media_dir, session=session, dir_record=dir_record) - query_tokens = [_.lower() for _ in re.split( - self._filename_separators, query.strip())] + query_tokens = [ + _.lower() for _ in re.split(self._filename_separators, query.strip()) + ] - for file_record in session.query(MediaFile.path). \ - join(MediaFileToken). \ - join(MediaToken). \ - filter(MediaToken.token.in_(query_tokens)). \ - group_by(MediaFile.path). \ - having(func.count(MediaFileToken.token_id) >= len(query_tokens)): + for file_record in ( + session.query(MediaFile.path) + .join(MediaFileToken) + .join(MediaToken) + .filter(MediaToken.token.in_(query_tokens)) + .group_by(MediaFile.path) + .having(func.count(MediaFileToken.token_id) >= len(query_tokens)) + ): if os.path.isfile(file_record.path): results[file_record.path] = { 'url': 'file://' + file_record.path, 'title': os.path.basename(file_record.path), - 'size': os.path.getsize(file_record.path) + 'size': os.path.getsize(file_record.path), } return results.values() @@ -223,11 +253,12 @@ class LocalMediaSearcher(MediaSearcher): # --- Table definitions + class MediaDirectory(Base): - """ Models the MediaDirectory table """ + """Models the MediaDirectory table""" __tablename__ = 'MediaDirectory' - __table_args__ = ({'sqlite_autoincrement': True}) + __table_args__ = {'sqlite_autoincrement': True} id = Column(Integer, primary_key=True) path = Column(String) @@ -243,14 +274,15 @@ class MediaDirectory(Base): class MediaFile(Base): - """ Models the MediaFile table """ + """Models the MediaFile table""" __tablename__ = 'MediaFile' - __table_args__ = ({'sqlite_autoincrement': True}) + __table_args__ = {'sqlite_autoincrement': True} id = Column(Integer, primary_key=True) - directory_id = Column(Integer, ForeignKey( - 'MediaDirectory.id', ondelete='CASCADE'), nullable=False) + directory_id = Column( + Integer, ForeignKey('MediaDirectory.id', ondelete='CASCADE'), nullable=False + ) path = Column(String, nullable=False, unique=True) indexed_at = Column(DateTime) @@ -265,10 +297,10 @@ class MediaFile(Base): class MediaToken(Base): - """ Models the MediaToken table """ + """Models the MediaToken table""" __tablename__ = 'MediaToken' - __table_args__ = ({'sqlite_autoincrement': True}) + __table_args__ = {'sqlite_autoincrement': True} id = Column(Integer, primary_key=True) token = Column(String, nullable=False, unique=True) @@ -282,14 +314,16 @@ class MediaToken(Base): class MediaFileToken(Base): - """ Models the MediaFileToken table """ + """Models the MediaFileToken table""" __tablename__ = 'MediaFileToken' - file_id = Column(Integer, ForeignKey('MediaFile.id', ondelete='CASCADE'), - nullable=False) - token_id = Column(Integer, ForeignKey('MediaToken.id', ondelete='CASCADE'), - nullable=False) + file_id = Column( + Integer, ForeignKey('MediaFile.id', ondelete='CASCADE'), nullable=False + ) + token_id = Column( + Integer, ForeignKey('MediaToken.id', ondelete='CASCADE'), nullable=False + ) __table_args__ = (PrimaryKeyConstraint(file_id, token_id), {}) @@ -301,4 +335,5 @@ class MediaFileToken(Base): record.token_id = token_id return record + # vim:sw=4:ts=4:et: diff --git a/platypush/user/__init__.py b/platypush/user/__init__.py index 673497b7f..c6c2bcd22 100644 --- a/platypush/user/__init__.py +++ b/platypush/user/__init__.py @@ -13,11 +13,13 @@ except ImportError: from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode from sqlalchemy import Column, Integer, String, DateTime, ForeignKey -from sqlalchemy.orm import make_transient -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import make_transient, declarative_base from platypush.context import get_plugin -from platypush.exceptions.user import InvalidJWTTokenException, InvalidCredentialsException +from platypush.exceptions.user import ( + InvalidJWTTokenException, + InvalidCredentialsException, +) from platypush.utils import get_or_generate_jwt_rsa_key_pair Base = declarative_base() @@ -68,8 +70,12 @@ class UserManager: if user: raise NameError('The user {} already exists'.format(username)) - record = User(username=username, password=self._encrypt_password(password), - created_at=datetime.datetime.utcnow(), **kwargs) + record = User( + username=username, + password=self._encrypt_password(password), + created_at=datetime.datetime.utcnow(), + **kwargs + ) session.add(record) session.commit() @@ -93,10 +99,16 @@ class UserManager: def authenticate_user_session(self, session_token): with self.db.get_session() as session: - user_session = session.query(UserSession).filter_by(session_token=session_token).first() + user_session = ( + session.query(UserSession) + .filter_by(session_token=session_token) + .first() + ) if not user_session or ( - user_session.expires_at and user_session.expires_at < datetime.datetime.utcnow()): + user_session.expires_at + and user_session.expires_at < datetime.datetime.utcnow() + ): return None, None user = session.query(User).filter_by(user_id=user_session.user_id).first() @@ -108,7 +120,9 @@ class UserManager: if not user: raise NameError('No such user: {}'.format(username)) - user_sessions = session.query(UserSession).filter_by(user_id=user.user_id).all() + user_sessions = ( + session.query(UserSession).filter_by(user_id=user.user_id).all() + ) for user_session in user_sessions: session.delete(user_session) @@ -118,7 +132,11 @@ class UserManager: def delete_user_session(self, session_token): with self.db.get_session() as session: - user_session = session.query(UserSession).filter_by(session_token=session_token).first() + user_session = ( + session.query(UserSession) + .filter_by(session_token=session_token) + .first() + ) if not user_session: return False @@ -134,14 +152,18 @@ class UserManager: return None if expires_at: - if isinstance(expires_at, int) or isinstance(expires_at, float): + if isinstance(expires_at, (int, float)): expires_at = datetime.datetime.fromtimestamp(expires_at) elif isinstance(expires_at, str): expires_at = datetime.datetime.fromisoformat(expires_at) - user_session = UserSession(user_id=user.user_id, session_token=self.generate_session_token(), - csrf_token=self.generate_session_token(), created_at=datetime.datetime.utcnow(), - expires_at=expires_at) + user_session = UserSession( + user_id=user.user_id, + session_token=self.generate_session_token(), + csrf_token=self.generate_session_token(), + created_at=datetime.datetime.utcnow(), + expires_at=expires_at, + ) session.add(user_session) session.commit() @@ -179,9 +201,19 @@ class UserManager: :param session_token: Session token. """ with self.db.get_session() as session: - return session.query(User).join(UserSession).filter_by(session_token=session_token).first() + return ( + session.query(User) + .join(UserSession) + .filter_by(session_token=session_token) + .first() + ) - def generate_jwt_token(self, username: str, password: str, expires_at: Optional[datetime.datetime] = None) -> str: + def generate_jwt_token( + self, + username: str, + password: str, + expires_at: Optional[datetime.datetime] = None, + ) -> str: """ Create a user JWT token for API usage. @@ -253,10 +285,10 @@ class UserManager: class User(Base): - """ Models the User table """ + """Models the User table""" __tablename__ = 'user' - __table_args__ = ({'sqlite_autoincrement': True}) + __table_args__ = {'sqlite_autoincrement': True} user_id = Column(Integer, primary_key=True) username = Column(String, unique=True, nullable=False) @@ -265,10 +297,10 @@ class User(Base): class UserSession(Base): - """ Models the UserSession table """ + """Models the UserSession table""" __tablename__ = 'user_session' - __table_args__ = ({'sqlite_autoincrement': True}) + __table_args__ = {'sqlite_autoincrement': True} session_id = Column(Integer, primary_key=True) session_token = Column(String, unique=True, nullable=False)