Support for multiple entity types/plugins filter on entities.get

This commit is contained in:
Fabio Manganiello 2022-04-10 21:23:03 +02:00
parent 532217be12
commit 17615ff028
Signed by: blacklight
GPG key ID: D90FBA7F76362774

View file

@ -1,7 +1,7 @@
from queue import Queue, Empty from queue import Queue, Empty
from threading import Thread from threading import Thread
from time import time from time import time
from typing import Optional, Any from typing import Optional, Any, Collection
from platypush.context import get_plugin from platypush.context import get_plugin
from platypush.entities import Entity, get_plugin_entity_registry, get_entities_registry from platypush.entities import Entity, get_plugin_entity_registry, get_entities_registry
@ -23,26 +23,43 @@ class EntitiesPlugin(Plugin):
return db return db
@action @action
def get(self, type: str = 'entity', **filter): def get(
self,
types: Optional[Collection[str]] = None,
plugins: Optional[Collection[str]] = None,
**filter,
):
""" """
Retrieve a list of entities. Retrieve a list of entities.
:param type: Entity type, as specified by the (lowercase) class name and table name. :param types: Entity types, as specified by the (lowercase) class name and table name.
Default: `entity` (retrieve all the types) Default: all entities.
:param plugins: Filter by plugin IDs (default: all plugins).
:param filter: Filter entities with these criteria (e.g. `name`, `id`, :param filter: Filter entities with these criteria (e.g. `name`, `id`,
`state`, `plugin` etc.) `state`, `type`, `plugin` etc.)
""" """
entity_registry = get_entities_registry() entity_registry = get_entities_registry()
selected_types = []
all_types = {e.__tablename__.lower(): e for e in entity_registry} all_types = {e.__tablename__.lower(): e for e in entity_registry}
entity_type = all_types.get(type.lower()) if types:
assert ( selected_types = {t.lower() for t in types}
entity_type entity_types = {t: et for t, et in all_types.items() if t in selected_types}
), f'No such entity type: {type}. Supported types: {list(all_types.keys())}' invalid_types = selected_types.difference(entity_types.keys())
assert not invalid_types, (
f'No such entity types: {invalid_types}. '
f'Supported types: {list(all_types.keys())}'
)
selected_types = entity_types.keys()
db = self._get_db() db = self._get_db()
with db.get_session() as session: with db.get_session() as session:
query = session.query(entity_type) query = session.query(Entity)
if selected_types:
query = query.filter(Entity.type.in_(selected_types))
if plugins:
query = query.filter(Entity.plugin.in_(plugins))
if filter: if filter:
query = query.filter_by(**filter) query = query.filter_by(**filter)
@ -51,37 +68,35 @@ class EntitiesPlugin(Plugin):
@action @action
def scan( def scan(
self, self,
type: Optional[str] = None, types: Optional[Collection[str]] = None,
plugin: Optional[str] = None, plugins: Optional[Collection[str]] = None,
timeout: Optional[float] = 30.0, timeout: Optional[float] = 30.0,
): ):
""" """
(Re-)scan entities and return the updated results. (Re-)scan entities and return the updated results.
:param type: Filter by entity type (e.g. `switch`, `light`, `sensor` etc.). Default: all. :param types: Filter by entity types (e.g. `switch`, `light`, `sensor` etc.).
:param plugin: Filter by plugin name (e.g. `switch.tplink` or `light.hue`). Default: all. :param plugins: Filter by plugin names (e.g. `switch.tplink` or `light.hue`).
:param timeout: Scan timeout in seconds. Default: 30. :param timeout: Scan timeout in seconds. Default: 30.
""" """
filter = {} filter = {}
plugin_registry = get_plugin_entity_registry() plugin_registry = get_plugin_entity_registry()
if plugin: if plugins:
filter['plugin'] = plugin filter['plugins'] = plugins
plugin_registry['by_plugin'] = { plugin_registry['by_plugin'] = {
**( plugin: plugin_registry['by_plugin'][plugin]
{plugin: plugin_registry['by_plugin'][plugin]} for plugin in plugins
if plugin in plugin_registry['by_plugin'] if plugin in plugin_registry['by_plugin']
else {}
)
} }
if type: if types:
filter['type'] = type filter['types'] = types
filter_plugins = set(plugin_registry['by_entity_type'].get(type, [])) filter_entity_types = set(types)
plugin_registry['by_plugin'] = { plugin_registry['by_plugin'] = {
plugin_name: entity_types plugin_name: entity_types
for plugin_name, entity_types in plugin_registry['by_plugin'].items() for plugin_name, entity_types in plugin_registry['by_plugin'].items()
if plugin_name in filter_plugins if any(t for t in entity_types if t in filter_entity_types)
} }
enabled_plugins = plugin_registry['by_plugin'].keys() enabled_plugins = plugin_registry['by_plugin'].keys()