A more resilient logic on entity copy/serialization to prevent ObjectDeletedError

This commit is contained in:
Fabio Manganiello 2023-04-13 17:16:21 +02:00
parent a499b7bc2f
commit 4c19535612

View file

@ -1,10 +1,11 @@
import logging
import inspect import inspect
import json import json
import pathlib import pathlib
import types import types
from datetime import datetime from datetime import datetime
import pkgutil import pkgutil
from typing import Callable, Dict, Final, Set, Type, Tuple, Any from typing import Callable, Dict, Final, Optional, Set, Type, Tuple, Any
from dateutil.tz import tzutc from dateutil.tz import tzutc
from sqlalchemy import ( from sqlalchemy import (
@ -20,6 +21,7 @@ from sqlalchemy import (
inspect as schema_inspect, inspect as schema_inspect,
) )
from sqlalchemy.orm import ColumnProperty, Mapped, backref, relationship from sqlalchemy.orm import ColumnProperty, Mapped, backref, relationship
from sqlalchemy.orm.exc import ObjectDeletedError
from platypush.common.db import Base from platypush.common.db import Base
from platypush.message import JSONAble, Message from platypush.message import JSONAble, Message
@ -40,6 +42,8 @@ requirements are missing, and if those plugins aren't enabled then we shouldn't
fail. fail.
""" """
logger = logging.getLogger(__name__)
if 'entity' not in Base.metadata: if 'entity' not in Base.metadata:
@ -127,8 +131,18 @@ if 'entity' not in Base.metadata:
to reuse entity objects in other threads or outside of their to reuse entity objects in other threads or outside of their
associated SQLAlchemy session context. associated SQLAlchemy session context.
""" """
def key_value_pair(col: ColumnProperty):
try:
return (col.key, getattr(self, col.key, None))
except ObjectDeletedError as e:
return None
return self.__class__( return self.__class__(
**{col.key: getattr(self, col.key, None) for col in self.columns}, **dict(
key_value_pair(col)
for col in self.columns
if key_value_pair(col) is not None
),
children=[child.copy() for child in self.children], children=[child.copy() for child in self.children],
) )
@ -140,32 +154,44 @@ if 'entity' not in Base.metadata:
return val return val
def _column_name(self, col: ColumnProperty) -> str: def _column_name(self, col: ColumnProperty) -> Optional[str]:
""" """
Normalizes the column name, taking into account native columns and Normalizes the column name, taking into account native columns and
columns mapped to properties. columns mapped to properties.
""" """
normalized_name = col.key.lstrip('_') try:
if len(col.key.lstrip('_')) == col.key or not hasattr( normalized_name = col.key.lstrip('_')
self, normalized_name if len(col.key.lstrip('_')) == col.key or not hasattr(
): self, normalized_name
return col.key # It's not a hidden column with a mapped property ):
return col.key # It's not a hidden column with a mapped property
return normalized_name return normalized_name
except ObjectDeletedError as e:
logger.warning(
f'Could not access column "{col.key}" for entity ID "{self.id}": {e}'
)
return None
def _column_to_pair(self, col: ColumnProperty) -> Tuple[str, Any]: def _column_to_pair(self, col: ColumnProperty) -> tuple:
""" """
Utility method that, given a column, returns a pair containing its Utility method that, given a column, returns a pair containing its
normalized name and its serialized value. normalized name and its serialized value.
""" """
col_name = self._column_name(col) col_name = self._column_name(col)
if col_name is None:
return tuple()
return col_name, self._serialize_value(col_name) return col_name, self._serialize_value(col_name)
def to_dict(self) -> dict: def to_dict(self) -> dict:
""" """
Returns the current entity as a flatten dictionary. Returns the current entity as a flatten dictionary.
""" """
return dict(self._column_to_pair(col) for col in self.columns) return dict(
self._column_to_pair(col)
for col in self.columns
if self._column_to_pair(col)
)
def to_json(self) -> dict: def to_json(self) -> dict:
""" """
@ -225,11 +251,6 @@ if 'entity' not in Base.metadata:
def _discover_entity_types(): def _discover_entity_types():
from platypush.context import get_plugin
logger = get_plugin('logger')
assert logger
for loader, modname, _ in pkgutil.walk_packages( for loader, modname, _ in pkgutil.walk_packages(
path=[str(pathlib.Path(__file__).parent.absolute())], path=[str(pathlib.Path(__file__).parent.absolute())],
prefix=__package__ + '.', prefix=__package__ + '.',