forked from platypush/platypush
Large refactor of the entities engine.
This commit is contained in:
parent 9ddebb920f
commit b0464219d3
9 changed files with 757 additions and 360 deletions
@@ -1,19 +1,21 @@
 import inspect
 import json
 import pathlib
 import types
 from datetime import datetime
+from typing import Dict, Mapping, Type, Tuple, Any
+
+from dateutil.tz import tzutc
-from typing import Mapping, Type, Tuple, Any

 import pkgutil
 from sqlalchemy import (
     Boolean,
     Column,
+    DateTime,
     ForeignKey,
     Index,
     Integer,
-    String,
-    DateTime,
+    JSON,
+    String,
     UniqueConstraint,
     inspect as schema_inspect,
 )
@@ -23,10 +25,11 @@ from platypush.common.db import Base
 from platypush.message import JSONAble

 entities_registry: Mapping[Type['Entity'], Mapping] = {}
+entity_types_registry: Dict[str, Type['Entity']] = {}


-class Entity(Base):
+if 'entity' not in Base.metadata:
+
+    class Entity(Base):
         """
         Model for a general-purpose platform entity.
         """
@@ -34,7 +37,7 @@ class Entity(Base):
     __tablename__ = 'entity'

     id = Column(Integer, autoincrement=True, primary_key=True)
-    external_id = Column(String, nullable=True)
+    external_id = Column(String, nullable=False)
     name = Column(String, nullable=False, index=True)
     description = Column(String)
     type = Column(String, nullable=False, index=True)
@@ -54,7 +57,9 @@ class Entity(Base):
         DateTime(timezone=False), default=datetime.utcnow(), nullable=False
     )
     updated_at = Column(
-        DateTime(timezone=False), default=datetime.utcnow(), onupdate=datetime.utcnow()
+        DateTime(timezone=False),
+        default=datetime.utcnow(),
+        onupdate=datetime.utcnow(),
     )

     parent: Mapped['Entity'] = relationship(
@@ -62,6 +67,7 @@ class Entity(Base):
         remote_side=[id],
         uselist=False,
         lazy=True,
+        post_update=True,
         backref=backref(
             'children',
             remote_side=[parent_id],
@@ -89,17 +95,50 @@
         inspector = schema_inspect(cls)
         return tuple(inspector.mapper.column_attrs)

+    @property
+    def entity_key(self) -> Tuple[str, str]:
+        """
+        This method returns the "external" key of an entity.
+        """
+        return (str(self.external_id), str(self.plugin))
+
     def _serialize_value(self, col: ColumnProperty) -> Any:
         val = getattr(self, col.key)
         if isinstance(val, datetime):
             # All entity timestamps are in UTC
-            val = val.isoformat() + '+00:00'
+            val = val.replace(tzinfo=tzutc()).isoformat()

         return val

+    def copy(self) -> 'Entity':
+        args = {c.key: getattr(self, c.key) for c in self.columns}
+        # if self.parent:
+        #     args['parent'] = self.parent.copy()
+
+        # args['children'] = [c.copy() for c in self.children]
+        return self.__class__(**args)
+
     def to_json(self) -> dict:
         return {col.key: self._serialize_value(col) for col in self.columns}

+    def __repr__(self):
+        return str(self)
+
+    def __str__(self):
+        return json.dumps(self.to_json())
+
+    def __setattr__(self, key, value):
+        matching_columns = [c for c in self.columns if c.expression.name == key]
+
+        if (
+            matching_columns
+            and issubclass(type(matching_columns[0].columns[0].type), DateTime)
+            and isinstance(value, str)
+        ):
+            value = datetime.fromisoformat(value)
+
+        return super().__setattr__(key, value)
+
     def get_plugin(self):
         from platypush.context import get_plugin
@@ -113,10 +152,9 @@ class Entity(Base):
         assert method, f'No such action: {self.plugin}.{action}'
         return method(self.external_id or self.name, *args, **kwargs)


-# Inject the JSONAble mixin (Python goes nuts if done through
-# standard multiple inheritance with an SQLAlchemy ORM class)
-Entity.__bases__ = Entity.__bases__ + (JSONAble,)
+    # Inject the JSONAble mixin (Python goes nuts if done through
+    # standard multiple inheritance with an SQLAlchemy ORM class)
+    Entity.__bases__ = Entity.__bases__ + (JSONAble,)


 def _discover_entity_types():
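Two of the additions above are easiest to grasp with a quick example: the entity_key property, which gives every entity a composite (external_id, plugin) lookup key, and the __setattr__ hook, which coerces ISO-8601 strings assigned to DateTime columns. A minimal sketch with hypothetical values (not part of the commit):

    e = Entity(external_id='hue-light-1', name='Ceiling lamp', plugin='light.hue')

    # entity_key stringifies (external_id, plugin):
    assert e.entity_key == ('hue-light-1', 'light.hue')

    # DateTime columns transparently accept ISO-8601 strings, which
    # __setattr__ converts through datetime.fromisoformat:
    e.created_at = '2022-12-19T10:00:00+00:00'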
@@ -1,281 +0,0 @@
import json
from logging import getLogger
from queue import Queue, Empty
from threading import Thread, Event, RLock
from time import time
from typing import Iterable, List, Optional

from sqlalchemy import and_, or_
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.orm import Session, make_transient

from platypush.context import get_bus, get_plugin
from platypush.message.event.entities import EntityUpdateEvent

from ._base import Entity


class EntitiesEngine(Thread):
    # Processing queue timeout in seconds
    _queue_timeout = 2.0

    def __init__(self):
        obj_name = self.__class__.__name__
        super().__init__(name=obj_name)
        self.logger = getLogger(name=obj_name)
        self._queue = Queue()
        self._should_stop = Event()
        self._entities_awaiting_flush = set()
        self._entities_cache_lock = RLock()
        self._entities_cache = {
            'by_id': {},
            'by_external_id_and_plugin': {},
            'by_name_and_plugin': {},
        }

    def _get_session(self, *args, **kwargs):
        db = get_plugin('db')
        assert db
        return db.get_session(*args, **kwargs)

    def _get_cached_entity(self, entity: Entity) -> Optional[dict]:
        if entity.id:
            e = self._entities_cache['by_id'].get(entity.id)
            if e:
                return e

        if entity.external_id and entity.plugin:
            e = self._entities_cache['by_external_id_and_plugin'].get(
                (entity.external_id, entity.plugin)
            )
            if e:
                return e

        if entity.name and entity.plugin:
            e = self._entities_cache['by_name_and_plugin'].get(
                (entity.name, entity.plugin)
            )
            if e:
                return e

    @staticmethod
    def _cache_repr(entity: Entity) -> dict:
        repr_ = entity.to_json()
        repr_.pop('data', None)
        repr_.pop('meta', None)
        repr_.pop('created_at', None)
        repr_.pop('updated_at', None)
        return repr_

    def _cache_entities(self, *entities: Entity, overwrite_cache=False):
        for entity in entities:
            e = self._cache_repr(entity)
            if not overwrite_cache:
                existing_entity = self._entities_cache['by_id'].get(entity.id)
                if existing_entity:
                    for k, v in existing_entity.items():
                        if e.get(k) is None:
                            e[k] = v

            if entity.id:
                self._entities_cache['by_id'][entity.id] = e
            if entity.external_id and entity.plugin:
                self._entities_cache['by_external_id_and_plugin'][
                    (entity.external_id, entity.plugin)
                ] = e
            if entity.name and entity.plugin:
                self._entities_cache['by_name_and_plugin'][
                    (entity.name, entity.plugin)
                ] = e

    def _populate_entity_id_from_cache(self, new_entity: Entity):
        with self._entities_cache_lock:
            cached_entity = self._get_cached_entity(new_entity)
            if cached_entity and cached_entity.get('id'):
                new_entity.id = cached_entity['id']
            if new_entity.id:
                self._cache_entities(new_entity)

    def _init_entities_cache(self):
        with self._get_session() as session:
            entities = session.query(Entity).all()
            for entity in entities:
                make_transient(entity)

        with self._entities_cache_lock:
            self._cache_entities(*entities, overwrite_cache=True)

        self.logger.info('Entities cache initialized')

    def _process_event(self, entity: Entity):
        self._populate_entity_id_from_cache(entity)
        if entity.id:
            get_bus().post(EntityUpdateEvent(entity=entity))
        else:
            self._entities_awaiting_flush.add(self._to_entity_awaiting_flush(entity))

    @staticmethod
    def _to_entity_awaiting_flush(entity: Entity):
        e = entity.to_json()
        return json.dumps(
            {k: v for k, v in e.items() if k in {'external_id', 'name', 'plugin'}},
            sort_keys=True,
        )

    def post(self, *entities: Entity):
        for entity in entities:
            self._queue.put(entity)

    @property
    def should_stop(self) -> bool:
        return self._should_stop.is_set()

    def stop(self):
        self._should_stop.set()

    def run(self):
        super().run()
        self.logger.info('Started entities engine')
        self._init_entities_cache()

        while not self.should_stop:
            msgs = []
            last_poll_time = time()

            while not self.should_stop and (
                time() - last_poll_time < self._queue_timeout
            ):
                try:
                    msg = self._queue.get(block=True, timeout=0.5)
                except Empty:
                    continue

                if msg:
                    msgs.append(msg)
                    # Trigger an EntityUpdateEvent if there has
                    # been a change on the entity state
                    self._process_event(msg)

            if not msgs or self.should_stop:
                continue

            try:
                self._process_entities(*msgs)
            except Exception as e:
                self.logger.error('Error while processing entity updates: ' + str(e))
                self.logger.exception(e)

        self.logger.info('Stopped entities engine')

    def get_if_exist(
        self, session: Session, entities: Iterable[Entity]
    ) -> List[Entity]:
        existing_entities = {
            (
                str(entity.external_id)
                if entity.external_id is not None
                else entity.name,
                entity.plugin,
            ): entity
            for entity in session.query(Entity)
            .filter(
                or_(
                    *[
                        and_(
                            Entity.external_id == entity.external_id,
                            Entity.plugin == entity.plugin,
                        )
                        if entity.external_id is not None
                        else and_(
                            Entity.name == entity.name,
                            Entity.type == entity.type,
                            Entity.plugin == entity.plugin,
                        )
                        for entity in entities
                    ]
                )
            )
            .all()
        }

        return [
            existing_entities.get(
                (
                    str(entity.external_id)
                    if entity.external_id is not None
                    else entity.name,
                    entity.plugin,
                ),
                None,
            )
            for entity in entities
        ]

    def _process_entities(self, *entities: Entity):  # type: ignore
        with self._get_session(locked=True) as session:
            # Ensure that the internal IDs are set to null before the merge
            for e in entities:
                e.id = None  # type: ignore

            entities: List[Entity] = self._merge_entities(session, entities)
            session.add_all(entities)
            session.commit()

            for e in entities:
                try:
                    session.expunge(e)
                except InvalidRequestError:
                    pass

        with self._entities_cache_lock:
            for entity in entities:
                self._cache_entities(entity, overwrite_cache=True)

        entities_awaiting_flush = {*self._entities_awaiting_flush}
        for entity in entities:
            e = self._to_entity_awaiting_flush(entity)
            if e in entities_awaiting_flush:
                self._process_event(entity)
                self._entities_awaiting_flush.remove(e)

    def _merge_entities(
        self, session: Session, entities: Iterable[Entity]
    ) -> List[Entity]:
        existing_entities = self.get_if_exist(session, entities)
        new_entities = {}
        entities_map = {}

        # Get the latest update for each ((id|name), plugin) record
        for e in entities:
            entities_map[self.entity_key(e)] = e

        # Retrieve existing records and merge them
        for i, entity in enumerate(entities):
            existing_entity = existing_entities[i]
            if existing_entity:
                existing_entity.children = self._merge_entities(
                    session, entity.children
                )
                entity = self._merge_entity_columns(entity, existing_entity)

            new_entities[self.entity_key(entity)] = entity

        return list(new_entities.values())

    @classmethod
    def _merge_entity_columns(cls, entity: Entity, existing_entity: Entity) -> Entity:
        columns = [col.key for col in entity.columns]
        for col in columns:
            if col == 'meta':
                existing_entity.meta = {  # type: ignore
                    **(existing_entity.meta or {}),  # type: ignore
                    **(entity.meta or {}),  # type: ignore
                }
            elif col not in ('id', 'created_at'):
                setattr(existing_entity, col, getattr(entity, col))

        return existing_entity

    @staticmethod
    def entity_key(entity: Entity):
        return ((entity.external_id or entity.name), entity.plugin)
76 platypush/entities/_engine/__init__.py Normal file
@@ -0,0 +1,76 @@
from logging import getLogger
from threading import Thread, Event

from platypush.entities import Entity
from platypush.utils import set_thread_name

from platypush.entities._engine.notifier import EntityNotifier
from platypush.entities._engine.queue import EntitiesQueue
from platypush.entities._engine.repo import EntitiesRepository


class EntitiesEngine(Thread):
    """
    This thread runs the "brain" of the entities data persistence logic.

    Its purpose is to:

        1. Consume entities from a queue (synchronized with the upstream
           integrations that produce/handle them). The producer/consumer model
           ensures that only this thread writes to the database, packs events
           together (preventing excessive writes and throttling events), and
           prevents race conditions when SQLite is used.
        2. Merge any existing entities with their newer representations.
        3. Update the entities taxonomy.
        4. Persist the new state to the entities database.
        5. Trigger events for the updated entities.
    """

    def __init__(self):
        obj_name = self.__class__.__name__
        super().__init__(name=obj_name)

        self.logger = getLogger(name=obj_name)
        self._should_stop = Event()
        self._queue = EntitiesQueue(stop_event=self._should_stop)
        self._repo = EntitiesRepository()
        self._notifier = EntityNotifier(self._repo._cache)

    def post(self, *entities: Entity):
        self._queue.put(*entities)

    @property
    def should_stop(self) -> bool:
        return self._should_stop.is_set()

    def stop(self):
        self._should_stop.set()

    def run(self):
        super().run()
        set_thread_name('entities')
        self.logger.info('Started entities engine')

        while not self.should_stop:
            # Get a batch of entity updates forwarded by other integrations
            entities = self._queue.get()
            if not entities or self.should_stop:
                continue

            # Trigger/prepare EntityUpdateEvent objects
            for entity in entities:
                self._notifier.notify(entity)

            # Store the batch of entities
            try:
                entities = self._repo.save(*entities)
            except Exception as e:
                self.logger.error('Error while processing entity updates: ' + str(e))
                self.logger.exception(e)
                continue

            # Flush any pending notifications
            self._notifier.flush(*entities)

        self.logger.info('Stopped entities engine')
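A rough usage sketch of the refactored engine (hypothetical, standalone; inside platypush the application object owns the engine, and the `db` plugin and the bus must be available):

    engine = EntitiesEngine()
    engine.start()

    # Integrations forward entity updates through post(); the engine batches
    # them, merges them with their stored versions and triggers events.
    engine.post(entity1, entity2)   # entity1/entity2: hypothetical Entity objects

    # On shutdown:
    engine.stop()
    engine.join()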
47 platypush/entities/_engine/notifier.py Normal file
@@ -0,0 +1,47 @@
from platypush.context import get_bus
from platypush.entities import Entity
from platypush.message.event.entities import EntityUpdateEvent

from platypush.entities._engine.repo.cache import EntitiesCache


class EntityNotifier:
    """
    This object is in charge of forwarding EntityUpdateEvent instances on the
    application bus when some entities are changed.
    """

    def __init__(self, cache: EntitiesCache):
        self._cache = cache
        self._entities_awaiting_flush = set()

    def _populate_entity_id_from_cache(self, new_entity: Entity):
        cached_entity = self._cache.get(new_entity)
        if cached_entity and cached_entity.id:
            new_entity.id = cached_entity.id
        if new_entity.id:
            self._cache.update(new_entity)

    def notify(self, entity: Entity):
        """
        Trigger an EntityUpdateEvent if the entity has been persisted, or queue
        it to the list of entities whose notifications will be flushed when the
        session is committed.
        """
        self._populate_entity_id_from_cache(entity)
        if entity.id:
            get_bus().post(EntityUpdateEvent(entity=entity))
        else:
            self._entities_awaiting_flush.add(entity.entity_key)

    def flush(self, *entities: Entity):
        """
        Flush and process any entities with pending EntityUpdateEvent
        notifications.
        """
        entities_awaiting_flush = {*self._entities_awaiting_flush}
        for entity in entities:
            key = entity.entity_key
            if key in entities_awaiting_flush:
                self.notify(entity)
                self._entities_awaiting_flush.remove(key)
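The notify()/flush() split is what lets the engine fire events for brand-new entities only once they have a database ID. A sketch of the flow (hypothetical entity; posting the event requires a running application bus):

    notifier = EntityNotifier(EntitiesCache())
    e = Entity(external_id='sensor-1', name='Sensor', plugin='zigbee.mqtt')

    notifier.notify(e)   # e has no id yet -> its entity_key is queued, no event
    e.id = 42            # suppose the repository has just flushed the entity
    notifier.flush(e)    # the pending notification is posted as an EntityUpdateEvent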
50 platypush/entities/_engine/queue.py Normal file
@@ -0,0 +1,50 @@
from queue import Queue, Empty
from threading import Event
from time import time
from typing import List, Optional

from platypush.entities import Entity


class EntitiesQueue(Queue):
    """
    Extends the ``Queue`` class to provide an abstraction that allows getting
    and putting multiple entities at once and synchronizing with the upstream
    caller.
    """

    def __init__(self, stop_event: Optional[Event] = None, timeout: float = 2.0):
        super().__init__()
        self._timeout = timeout
        self._should_stop = stop_event

    @property
    def should_stop(self) -> bool:
        return self._should_stop.is_set() if self._should_stop else False

    def get(self, block=True, timeout=None) -> List[Entity]:
        """
        Returns a batch of entities read from the queue.
        """
        timeout = timeout or self._timeout
        entities = []
        last_poll_time = time()

        while not self.should_stop and (time() - last_poll_time < timeout):
            try:
                entity = super().get(block=block, timeout=0.5)
            except Empty:
                continue

            if entity:
                entities.append(entity)

        return entities

    def put(self, *entities: Entity, block=True, timeout=None):
        """
        This method is called by an entity manager to update and persist the
        state of some entities.
        """
        for entity in entities:
            super().put(entity, block=block, timeout=timeout)
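The queue turns a burst of updates into a single batch, and therefore a single database write. A minimal sketch (e1/e2/e3 being hypothetical Entity objects):

    q = EntitiesQueue(timeout=2.0)
    q.put(e1, e2, e3)     # put() accepts multiple entities at once

    batch = q.get()       # keeps draining the queue for up to ~2 seconds
    assert batch == [e1, e2, e3]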
86 platypush/entities/_engine/repo/__init__.py Normal file
@@ -0,0 +1,86 @@
import logging
from typing import Dict, Iterable, Tuple

from sqlalchemy.orm import Session, make_transient

from platypush.entities import Entity
from platypush.entities._engine.repo.cache import EntitiesCache
from platypush.entities._engine.repo.db import EntitiesDb
from platypush.entities._engine.repo.merger import EntitiesMerger

logger = logging.getLogger('entities')


class EntitiesRepository:
    """
    This object is used to get and save entities, and it wraps database and
    cache objects.
    """

    def __init__(self):
        self._cache = EntitiesCache()
        self._db = EntitiesDb()
        self._merger = EntitiesMerger(self)
        self._init_entities_cache()

    def _init_entities_cache(self):
        """
        Initializes the repository with the existing entities.
        """
        logger.info('Initializing entities cache')
        with self._db.get_session() as session:
            entities = session.query(Entity).all()
            for entity in entities:
                make_transient(entity)

        self._cache.update(*entities, overwrite=True)
        logger.info('Entities cache initialized')

    def get(
        self, session: Session, entities: Iterable[Entity]
    ) -> Dict[Tuple[str, str], Entity]:
        """
        Given a set of entity objects, it returns those that already exist
        (or have the same ``entity_key``). It looks up both the cache and the
        database.
        """
        entities_map: Dict[Tuple[str, str], Entity] = {
            e.entity_key: e for e in entities
        }

        # Fetch the entities that exist in the cache
        existing_entities = {}
        # TODO UNCOMMENT THIS CODE TO ACTUALLY USE THE CACHE!
        # existing_entities = {
        #     key: self._entities_cache.by_external_id_and_plugin[key]
        #     for key in entities_map.keys()
        #     if key in self._entities_cache.by_external_id_and_plugin
        # }

        # Retrieve from the database the entities that are missing from the cache
        cache_miss_entities = {
            key: e for key, e in entities_map.items() if key not in existing_entities
        }

        cache_miss_existing_entities = self._db.fetch(
            session, cache_miss_entities.values()
        )

        # Update the cache
        self._cache.update(*cache_miss_existing_entities.values())

        # Return the union of the cached + retrieved entities
        existing_entities.update(cache_miss_existing_entities)
        return existing_entities

    def save(self, *entities: Entity) -> Iterable[Entity]:
        """
        Perform an upsert of entities after merging duplicates and rebuilding
        the taxonomies. It updates both the database and the cache.
        """
        with self._db.get_session(locked=True, autoflush=False) as session:
            merged_entities = self._merger.merge(session, entities)
            merged_entities = self._db.upsert(session, merged_entities)
            self._cache.update(*merged_entities, overwrite=True)

        return merged_entities
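Taken together, the repository gives the engine a single save() entry point. A rough sketch of the round trip (hypothetical entity; assumes the `db` plugin is configured):

    repo = EntitiesRepository()   # warms the cache from the database

    lamp = Entity(external_id='hue-1', name='Lamp', plugin='light.hue')
    saved = repo.save(lamp)       # merge -> upsert -> cache update
    # `saved` holds detached (expunged) entities with their IDs populated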
51 platypush/entities/_engine/repo/cache.py Normal file
@@ -0,0 +1,51 @@
from threading import RLock
from typing import Dict, Optional, Tuple

from platypush.entities import Entity


class EntitiesCache:
    """
    An auxiliary class to model an entities lookup cache with multiple keys.
    """

    def __init__(self):
        self.by_id: Dict[str, Entity] = {}
        self.by_external_id_and_plugin: Dict[Tuple[str, str], Entity] = {}
        self._lock = RLock()

    def get(self, entity: Entity) -> Optional[Entity]:
        """
        Retrieve the cached representation of an entity, if it exists.
        """
        if entity.id:
            e = self.by_id.get(str(entity.id))
            if e:
                return e

        if entity.external_id and entity.plugin:
            e = self.by_external_id_and_plugin.get(
                (str(entity.external_id), str(entity.plugin))
            )
            if e:
                return e

    def update(self, *entities: Entity, overwrite=False):
        """
        Update the cache with a list of new entities.
        """
        with self._lock:
            for entity in entities:
                if not overwrite:
                    existing_entity = self.by_id.get(str(entity.id))
                    if existing_entity:
                        for k, v in existing_entity.to_json().items():
                            if getattr(entity, k, None) is None:
                                setattr(entity, k, v)

                if entity.id:
                    self.by_id[str(entity.id)] = entity
                if entity.external_id and entity.plugin:
                    self.by_external_id_and_plugin[
                        (str(entity.external_id), str(entity.plugin))
                    ] = entity
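A quick sketch of the multi-key lookup (hypothetical values): once cached, an entity can be retrieved either by ID or by (external_id, plugin):

    cache = EntitiesCache()
    lamp = Entity(id=1, external_id='hue-1', name='Lamp', plugin='light.hue')
    cache.update(lamp)

    assert cache.get(Entity(id=1)) is lamp
    assert cache.get(Entity(external_id='hue-1', plugin='light.hue')) is lamp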
187 platypush/entities/_engine/repo/db.py Normal file
@@ -0,0 +1,187 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Iterable, List, Tuple

from sqlalchemy import and_, or_
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.orm import Session

from platypush.context import get_plugin
from platypush.entities import Entity


@dataclass
class _TaxonomyAwareEntity:
    """
    A support class used to map an entity and its level within a taxonomy to be
    flushed to the database.
    """

    entity: Entity
    level: int


class EntitiesDb:
    """
    This object is a facade around the entities database. It shouldn't be used
    directly. Instead, it is encapsulated by
    :class:`platypush.entities._engine.repo.EntitiesRepository`, which is in
    charge of caching as well.
    """

    def get_session(self, *args, **kwargs) -> Session:
        db = get_plugin('db')
        assert db
        return db.get_session(*args, **kwargs)

    def fetch(
        self, session: Session, entities: Iterable[Entity]
    ) -> Dict[Tuple[str, str], Entity]:
        """
        Given a set of entities, it returns those that already exist on the database.
        """
        if not entities:
            return {}

        entities_filter = or_(
            *[
                and_(
                    Entity.external_id == entity.external_id,
                    Entity.plugin == entity.plugin,
                )
                for entity in entities
            ]
        )

        query = session.query(Entity).filter(entities_filter)
        existing_entities = {entity.entity_key: entity for entity in query.all()}

        return {
            entity.entity_key: existing_entities[entity.entity_key]
            for entity in entities
            if existing_entities.get(entity.entity_key)
        }

    @staticmethod
    def _close_batch(batch: List[_TaxonomyAwareEntity], batches: List[List[Entity]]):
        if batch:
            batches.append([item.entity for item in batch])

        batch.clear()

    def _split_entity_batches_for_flush(
        self, entities: Iterable[Entity]
    ) -> List[List[Entity]]:
        """
        This method retrieves the root entities given a list of entities and
        generates batches of "flushable" entities ready for upsert using a BFS
        algorithm.

        This is needed because we want hierarchies of entities to be flushed
        starting from the top layer, once their parents have been appropriately
        rewired. Otherwise, we may end up with conflicts on entities that have
        already been flushed.
        """
        # Index children by parent_id and by parent_key
        children_by_parent_id = defaultdict(list)
        children_by_parent_key = defaultdict(list)
        for entity in entities:
            parent_key = None
            parent_id = entity.parent_id
            if entity.parent:
                parent_id = parent_id or entity.parent.id
                parent_key = entity.parent.entity_key

            if parent_id:
                children_by_parent_id[parent_id].append(entity)
            if parent_key:
                children_by_parent_key[parent_key].append(entity)

        # Find the root entities in the hierarchy (i.e. those that have a null
        # parent)
        root_entities = list(
            {
                e.entity_key: e
                for e in entities
                if e.parent is None and e.parent_id is None
            }.values()
        )

        # Prepare a list of entities to process through BFS starting with the
        # root nodes (i.e. level=0)
        entities_to_process = [
            _TaxonomyAwareEntity(entity=e, level=0) for e in root_entities
        ]

        batches = []
        current_batch = []

        while entities_to_process:
            # Pop the first element in the list (FIFO implementation)
            item = entities_to_process.pop(0)
            entity = item.entity
            level = item.level

            # If the depth has increased compared to the previous node, flush
            # the current batch and open a new one.
            if current_batch and current_batch[-1].level < level:
                self._close_batch(current_batch, batches)
            current_batch.append(item)

            # Index the children nodes by key
            children_to_process = {
                e.entity_key: e
                for e in children_by_parent_key.get(entity.entity_key, [])
            }

            # If this entity has already been persisted, add back its children
            # that haven't been updated, so we won't lose those connections
            if entity.id:
                children_to_process.update(
                    {e.entity_key: e for e in children_by_parent_id.get(entity.id, [])}
                )

            # Add all the updated+inserted+existing children to the next layer
            # to be expanded
            entities_to_process += [
                _TaxonomyAwareEntity(entity=e, level=level + 1)
                for e in children_to_process.values()
            ]

        # Close any pending batches
        self._close_batch(current_batch, batches)
        return batches

    def upsert(
        self,
        session: Session,
        entities: Iterable[Entity],
    ) -> Iterable[Entity]:
        """
        Persist a set of entities.
        """
        # Get the "unwrapped" batches
        batches = self._split_entity_batches_for_flush(entities)

        # Flush each batch as we process it
        for batch in batches:
            session.add_all(batch)
            session.flush()

        session.commit()

        all_entities = list(
            {
                entity.entity_key: entity for batch in batches for entity in batch
            }.values()
        )

        # Remove all the entities from the existing session, so they can be
        # accessed outside of this context
        for e in all_entities:
            try:
                session.expunge(e)
            except InvalidRequestError:
                pass

        return all_entities
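The effect of _split_entity_batches_for_flush is easiest to see on a toy taxonomy (hypothetical entities): parents always land in an earlier batch than their children, so each flush() can resolve the parent IDs of the next layer:

    db = EntitiesDb()
    house = Entity(external_id='house', name='House', plugin='demo')
    room = Entity(external_id='room', name='Room', plugin='demo', parent=house)
    lamp = Entity(external_id='lamp', name='Lamp', plugin='demo', parent=room)

    batches = db._split_entity_batches_for_flush([house, room, lamp])
    # -> [[house], [room], [lamp]]: one batch per taxonomy level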
143 platypush/entities/_engine/repo/merger.py Normal file
@@ -0,0 +1,143 @@
from typing import Dict, Iterable, List, Optional, Tuple

from sqlalchemy.orm import Session

from platypush.entities import Entity


class EntitiesMerger:
    """
    This object is in charge of detecting and merging entities that already
    exist on the database before flushing the session.
    """

    def __init__(self, repository):
        from platypush.entities._engine.repo import EntitiesRepository

        self._repo: EntitiesRepository = repository

    def merge(
        self,
        session: Session,
        entities: Iterable[Entity],
    ) -> List[Entity]:
        """
        Merge a set of entities with their existing representations, update
        the parent/child relationships, and return the merged list of
        entities.
        """
        new_entities = {}
        existing_entities = {}

        self._merge(
            session,
            entities,
            new_entities=new_entities,
            existing_entities=existing_entities,
        )

        return [*existing_entities.values(), *new_entities.values()]

    def _merge(
        self,
        session: Session,
        entities: Iterable[Entity],
        new_entities: Dict[Tuple[str, str], Entity],
        existing_entities: Dict[Tuple[str, str], Entity],
    ) -> List[Entity]:
        """
        (Recursive) inner implementation of the entity merge logic.
        """
        processed_entities = []
        existing_entities.update(self._repo.get(session, entities))

        # Retrieve existing records and merge them
        for entity in entities:
            key = entity.entity_key
            existing_entity = existing_entities.get(key, new_entities.get(key))
            parent_id, parent = self._update_parent(session, entity, new_entities)

            if existing_entity:
                # Update the parent
                if not parent_id and parent:
                    existing_entity.parent = parent
                else:
                    existing_entity.parent_id = parent_id  # type: ignore

                # Merge the other columns
                self._merge_columns(entity, existing_entity)
                entity = existing_entity
            else:
                # Add it to the map of new entities if the entity doesn't exist
                # on the repo
                new_entities[key] = entity

            processed_entities.append(entity)

        return processed_entities

    def _update_parent(
        self,
        session: Session,
        entity: Entity,
        new_entities: Dict[Tuple[str, str], Entity],
    ) -> Tuple[Optional[int], Optional[Entity]]:
        """
        Recursively update the hierarchy of an entity, moving upwards towards
        the parent.
        """
        parent_id: Optional[int] = entity.parent_id  # type: ignore
        parent: Optional[Entity] = entity.parent

        # If the entity has a parent with an ID, use that
        if parent and parent.id:
            parent_id = parent_id or parent.id

        # If there's no parent_id but there is a parent object, try to fetch
        # its stored version
        if not parent_id and parent:
            batch = list(self._repo.get(session, [parent]).values())

            # If the parent is already stored, use its ID
            if batch:
                parent = batch[0]
                parent_id = parent.id  # type: ignore

            # Otherwise, check if its key is already among those awaiting flush
            # and reuse the same objects (prevents SQLAlchemy from generating
            # duplicate inserts)
            else:
                temp_entity = new_entities.get(parent.entity_key)
                if temp_entity:
                    parent = entity.parent = temp_entity
                else:
                    new_entities[parent.entity_key] = parent

                # Recursively apply any changes up in the hierarchy
                self._update_parent(session, parent, new_entities=new_entities)

        # If we found a parent_id, populate it on the entity (and remove the
        # supporting relationship object so SQLAlchemy doesn't go nuts when
        # flushing)
        if parent_id:
            entity.parent = None
            entity.parent_id = parent_id  # type: ignore

        return parent_id, parent

    @classmethod
    def _merge_columns(cls, entity: Entity, existing_entity: Entity) -> Entity:
        """
        Merge two versions of an entity column by column.
        """
        columns = [col.key for col in entity.columns]
        for col in columns:
            if col == 'meta':
                existing_entity.meta = {  # type: ignore
                    **(existing_entity.meta or {}),  # type: ignore
                    **(entity.meta or {}),  # type: ignore
                }
            elif col not in ('id', 'created_at'):
                setattr(existing_entity, col, getattr(entity, col))

        return existing_entity
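A sketch of the column-merge semantics (hypothetical entities): a re-submitted entity is folded into its stored row instead of creating a duplicate, with 'id' and 'created_at' preserved and 'meta' merged rather than overwritten:

    stored = Entity(id=1, external_id='hue-1', plugin='light.hue',
                    name='Lamp', meta={'icon': 'lamp'})
    update = Entity(external_id='hue-1', plugin='light.hue',
                    name='Ceiling lamp', meta={'room': 'kitchen'})

    merged = EntitiesMerger._merge_columns(update, stored)
    assert merged is stored                  # the stored row is reused
    assert merged.name == 'Ceiling lamp'     # columns refreshed
    assert merged.meta == {'icon': 'lamp', 'room': 'kitchen'}
    assert merged.id == 1                    # id/created_at preserved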