From 17615ff028e591752ec721ee0cd7c21fcf594382 Mon Sep 17 00:00:00 2001 From: Fabio Manganiello Date: Sun, 10 Apr 2022 21:23:03 +0200 Subject: [PATCH] Support for multiple entity types/plugins filter on entities.get --- platypush/plugins/entities/__init__.py | 65 ++++++++++++++++---------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/platypush/plugins/entities/__init__.py b/platypush/plugins/entities/__init__.py index 40918d9b0..72b3bdb6e 100644 --- a/platypush/plugins/entities/__init__.py +++ b/platypush/plugins/entities/__init__.py @@ -1,7 +1,7 @@ from queue import Queue, Empty from threading import Thread from time import time -from typing import Optional, Any +from typing import Optional, Any, Collection from platypush.context import get_plugin from platypush.entities import Entity, get_plugin_entity_registry, get_entities_registry @@ -23,26 +23,43 @@ class EntitiesPlugin(Plugin): return db @action - def get(self, type: str = 'entity', **filter): + def get( + self, + types: Optional[Collection[str]] = None, + plugins: Optional[Collection[str]] = None, + **filter, + ): """ Retrieve a list of entities. - :param type: Entity type, as specified by the (lowercase) class name and table name. - Default: `entity` (retrieve all the types) + :param types: Entity types, as specified by the (lowercase) class name and table name. + Default: all entities. + :param plugins: Filter by plugin IDs (default: all plugins). :param filter: Filter entities with these criteria (e.g. `name`, `id`, - `state`, `plugin` etc.) + `state`, `type`, `plugin` etc.) """ entity_registry = get_entities_registry() + selected_types = [] all_types = {e.__tablename__.lower(): e for e in entity_registry} - entity_type = all_types.get(type.lower()) - assert ( - entity_type - ), f'No such entity type: {type}. Supported types: {list(all_types.keys())}' + if types: + selected_types = {t.lower() for t in types} + entity_types = {t: et for t, et in all_types.items() if t in selected_types} + invalid_types = selected_types.difference(entity_types.keys()) + assert not invalid_types, ( + f'No such entity types: {invalid_types}. ' + f'Supported types: {list(all_types.keys())}' + ) + + selected_types = entity_types.keys() db = self._get_db() with db.get_session() as session: - query = session.query(entity_type) + query = session.query(Entity) + if selected_types: + query = query.filter(Entity.type.in_(selected_types)) + if plugins: + query = query.filter(Entity.plugin.in_(plugins)) if filter: query = query.filter_by(**filter) @@ -51,37 +68,35 @@ class EntitiesPlugin(Plugin): @action def scan( self, - type: Optional[str] = None, - plugin: Optional[str] = None, + types: Optional[Collection[str]] = None, + plugins: Optional[Collection[str]] = None, timeout: Optional[float] = 30.0, ): """ (Re-)scan entities and return the updated results. - :param type: Filter by entity type (e.g. `switch`, `light`, `sensor` etc.). Default: all. - :param plugin: Filter by plugin name (e.g. `switch.tplink` or `light.hue`). Default: all. + :param types: Filter by entity types (e.g. `switch`, `light`, `sensor` etc.). + :param plugins: Filter by plugin names (e.g. `switch.tplink` or `light.hue`). :param timeout: Scan timeout in seconds. Default: 30. """ filter = {} plugin_registry = get_plugin_entity_registry() - if plugin: - filter['plugin'] = plugin + if plugins: + filter['plugins'] = plugins plugin_registry['by_plugin'] = { - **( - {plugin: plugin_registry['by_plugin'][plugin]} - if plugin in plugin_registry['by_plugin'] - else {} - ) + plugin: plugin_registry['by_plugin'][plugin] + for plugin in plugins + if plugin in plugin_registry['by_plugin'] } - if type: - filter['type'] = type - filter_plugins = set(plugin_registry['by_entity_type'].get(type, [])) + if types: + filter['types'] = types + filter_entity_types = set(types) plugin_registry['by_plugin'] = { plugin_name: entity_types for plugin_name, entity_types in plugin_registry['by_plugin'].items() - if plugin_name in filter_plugins + if any(t for t in entity_types if t in filter_entity_types) } enabled_plugins = plugin_registry['by_plugin'].keys()