From 026662f6b62ec6c6d49293a24f95511888518043 Mon Sep 17 00:00:00 2001 From: Fabio Manganiello Date: Sun, 26 Mar 2023 03:47:44 +0200 Subject: [PATCH] Added base schema for Marshmallow dataclasses. --- platypush/schemas/dataclasses/__init__.py | 56 +++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 platypush/schemas/dataclasses/__init__.py diff --git a/platypush/schemas/dataclasses/__init__.py b/platypush/schemas/dataclasses/__init__.py new file mode 100644 index 0000000000..9ca4c3185c --- /dev/null +++ b/platypush/schemas/dataclasses/__init__.py @@ -0,0 +1,56 @@ +from datetime import date, datetime +from uuid import UUID + +from marshmallow import ( + EXCLUDE, + Schema, + fields, + post_dump, +) + +from .. import Date, DateTime + + +class DataClassSchema(Schema): + """ + Base schema class for data classes that support Marshmallow schemas. + """ + + TYPE_MAPPING = { + date: Date, + datetime: DateTime, + UUID: fields.UUID, + } + + # pylint: disable=too-few-public-methods + class Meta: + """ + Exclude unknown fields. + """ + + unknown = EXCLUDE + + def _get_field(self, key: str) -> fields.Field: + """ + Returns the matching field by either name or data_key. + """ + if key in self.fields: + return self.fields[key] + + matching_fields = [f for f in self.fields.values() if key == f.data_key] + + assert ( + len(matching_fields) == 1 + ), f'Could not find field {key} in {self.__class__.__name__}' + + return matching_fields[0] + + @post_dump + def post_dump(self, data: dict, **__) -> dict: + # Use data_key parameters only for load + new_data = {} + for key, value in data.items(): + field = self._get_field(key) + new_data[field.name if field.data_key is not None else key] = value + + return new_data