diff --git a/platypush/message/event/__init__.py b/platypush/message/event/__init__.py index 61b489eee..b1dfbc5ed 100644 --- a/platypush/message/event/__init__.py +++ b/platypush/message/event/__init__.py @@ -2,9 +2,12 @@ import copy import json import logging import random +import re import time +from dataclasses import dataclass, field from datetime import date +from typing import Any from platypush.config import Config from platypush.message import Message @@ -90,6 +93,70 @@ class Event(Message): """Generate a unique event ID""" return ''.join([f'{random.randint(0, 255):02x}' for _ in range(16)]) + @staticmethod + def _is_relational_filter(filter: dict) -> bool: + """ + Check if a condition is a relational filter. + + For a condition to be a relational filter, it must have at least one + key starting with `$`. + """ + if not isinstance(filter, dict): + return False + return any(key.startswith('$') for key in filter) + + @staticmethod + def __relational_filter_matches(filter: dict, value: Any) -> bool: + """ + Return True if the conditions in the filter match the given event + arguments. + """ + for op, filter_val in filter.items(): + comparator = _event_filter_operators.get(op) + assert comparator, f'Invalid operator: {op}' + + # If this is a numeric or string filter, and one of the two values + # is null, return False - it doesn't make sense to run numeric or + # string comparison with null values. + if (op in _numeric_filter_operators or op in _string_filter_operators) and ( + filter_val is None or value is None + ): + return False + + # If this is a numeric-only or string-only filter, then the + # operands' types should be consistent with the operator. + if op in _numeric_filter_operators: + try: + value = float(value) + filter_val = float(filter_val) + except (ValueError, TypeError) as e: + raise AssertionError( + f'Could not convert either "{value}" nor "{filter_val} to a number' + ) from e + elif op in _string_filter_operators: + assert isinstance(filter_val, str) and isinstance(value, str), ( + f'Expected two strings, got "{filter_val}" ' + f'({type(filter_val)}) and "{value}" ({type(value)})' + ) + + if not comparator(value, filter_val): + return False + + return True + + @classmethod + def _relational_filter_matches(cls, filter: dict, value: Any) -> bool: + is_match = False + try: + is_match = cls.__relational_filter_matches(filter, value) + except AssertionError as e: + logger.error('Invalid filter: %s', e) + + if not is_match: + return False + + return True + def _matches_condition( self, condition: dict, @@ -102,24 +169,33 @@ class Event(Message): return False if isinstance(args[attr], str): - self._matches_argument( - argname=attr, condition_value=value, args=args, result=result - ) - - if result.is_match: - match_scores.append(result.score) + if self._is_relational_filter(value): + if not self._relational_filter_matches(value, args[attr]): + return False else: - return False - elif isinstance(value, dict): - if not isinstance(args[attr], dict): - return False + self._matches_argument( + argname=attr, condition_value=value, args=args, result=result + ) - return self._matches_condition( - condition=value, - args=args[attr], - result=result, - match_scores=match_scores, - ) + if result.is_match: + match_scores.append(result.score) + else: + return False + elif isinstance(value, dict): + if self._is_relational_filter(value): + if not self._relational_filter_matches(value, args[attr]): + return False + else: + if not isinstance(args[attr], dict): + return False + + if not self._matches_condition( + condition=value, + args=args[attr], + result=result, + match_scores=match_scores, + ): + return False elif args[attr] != value: return False @@ -188,6 +264,7 @@ class Event(Message): ) +@dataclass class EventMatchResult: """ When comparing an event against an event condition, you want to @@ -197,10 +274,9 @@ class EventMatchResult: highest score will win. """ - def __init__(self, is_match, score=0.0, parsed_args=None): - self.is_match = is_match - self.score = score - self.parsed_args = parsed_args or {} + is_match: bool + score: float = 0.0 + parsed_args: dict = field(default_factory=dict) def flatten(args): @@ -221,4 +297,19 @@ def flatten(args): flatten(args[i]) +_event_filter_operators = { + '$gt': lambda a, b: a > b, + '$gte': lambda a, b: a >= b, + '$lt': lambda a, b: a < b, + '$lte': lambda a, b: a <= b, + '$eq': lambda a, b: a == b, + '$ne': lambda a, b: a != b, + '$regex': lambda a, b: re.search(b, a), +} + +_numeric_filter_operators = {'$gt', '$gte', '$lt', '$lte'} + +_string_filter_operators = {'$regex'} + + # vim:sw=4:ts=4:et: diff --git a/tests/test_event_parse.py b/tests/test_event_parse.py index 8cef01c41..dfb45e449 100644 --- a/tests/test_event_parse.py +++ b/tests/test_event_parse.py @@ -87,6 +87,56 @@ def test_speech_recognized_event_parse(): assert not result.is_match +def test_condition_with_relational_operators(): + """ + Test relational operators used in event conditions. + """ + # Given: A condition with a relational operator. + condition = EventCondition.build( + { + 'type': 'platypush.message.event.ping.PingEvent', + 'message': {'foo': {'$gt': 25}}, + } + ) + + # When: An event with a value greater than 25 is received. + event = PingEvent(message={'foo': 26}) + + # Then: The condition is matched. + assert event.matches_condition(condition).is_match + + # When: An event with a value lower than 25 is received. + event = PingEvent(message={'foo': 24}) + + # Then: The condition is not matched. + assert not event.matches_condition(condition).is_match + + +def test_filter_with_regex_condition(): + """ + Test an event matcher with a regex filter on an attribute. + """ + # Given: A condition with a regex filter. + condition = EventCondition.build( + { + 'type': 'platypush.message.event.ping.PingEvent', + 'message': {'foo': {'$regex': '^ba[rz]'}}, + } + ) + + # When: An event with a matching string is received. + event = PingEvent(message={'foo': 'bart'}) + + # Then: The condition is matched. + assert event.matches_condition(condition).is_match + + # When: An event with a non-matching string is received. + event = PingEvent(message={'foo': 'back'}) + + # Then: The condition is not matched. + assert not event.matches_condition(condition).is_match + + if __name__ == '__main__': pytest.main()