platypush/platypush/schemas/dataclasses/__init__.py

77 lines
1.7 KiB
Python

from dataclasses import field
from datetime import date, datetime
from uuid import UUID
from marshmallow import (
EXCLUDE,
Schema,
fields,
post_dump,
pre_load,
)
from marshmallow.validate import Range
from .. import Date, DateTime
def percent_field(**kwargs):
    """
    Build a dataclass field holding a percentage as a float between 0 and 1.

    The field defaults to ``0.0`` (via ``float()``) and carries a Marshmallow
    ``Range(min=0, max=1)`` validator in its metadata.

    :param kwargs: Extra entries merged into the field metadata *after* the
        validator, so a caller-supplied ``validate`` entry replaces it.
    :return: A ``dataclasses.field`` configured as described above.
    """
    field_metadata = {'validate': Range(min=0, max=1)}
    field_metadata.update(kwargs)
    return field(metadata=field_metadata, default_factory=float)
class DataClassSchema(Schema):
    """
    Base schema class for data classes that support Marshmallow schemas.

    Unknown fields are excluded on load, and ``data_key`` aliases are only
    honoured when loading: on dump, output keys are mapped back to the
    schema's field names.
    """

    # Python type -> Marshmallow field class mappings for this schema.
    # NOTE(review): presumably consumed by marshmallow_dataclass when it
    # builds schemas from type hints - confirm.
    TYPE_MAPPING = {
        date: Date,
        datetime: DateTime,
        UUID: fields.UUID,
    }

    # pylint: disable=too-few-public-methods
    class Meta:
        """
        Exclude unknown fields.
        """

        unknown = EXCLUDE

    def _get_field(self, key: str) -> fields.Field:
        """
        Returns the matching field by either name or data_key.

        :param key: A field name, or the ``data_key`` of exactly one field.
        :return: The matching Marshmallow field.
        :raises AssertionError: If ``key`` is neither a field name nor the
            ``data_key`` of exactly one field.
        """
        if key in self.fields:
            return self.fields[key]

        matching_fields = [f for f in self.fields.values() if key == f.data_key]
        # Explicit check instead of ``assert``: an ``assert`` is stripped
        # under ``python -O``. A missing key would then surface as a bare
        # IndexError, and an ambiguous key would silently return the first
        # match. AssertionError is kept so callers see the same exception
        # type as before.
        if len(matching_fields) != 1:
            raise AssertionError(
                f'Could not find field {key} in {self.__class__.__name__}'
            )
        return matching_fields[0]

    # NOTE: the ``pre_load`` / ``post_dump`` methods below reuse the names of
    # the Marshmallow decorators. Inside this class body, any later
    # ``@pre_load`` / ``@post_dump`` would therefore resolve to the method,
    # not the decorator. Declare additional hooks above these definitions.
    @pre_load
    def pre_load(self, data, **__) -> dict:
        """
        Pre-load hook; currently returns ``data`` unchanged.

        NOTE(review): presumably an extension point for subclasses - confirm.
        """
        return data

    @post_dump
    def post_dump(self, data: dict, **__) -> dict:
        """
        Re-key dumped output from ``data_key`` aliases back to field names.

        :param data: The serialized object, keyed as Marshmallow emitted it.
        :return: A new dict with the same values. Keys of fields that declare
            a ``data_key`` are replaced by the field name; all other keys are
            kept as-is.
        """
        # Use data_key parameters only for load
        new_data = {}
        for key, value in data.items():
            # Named ``schema_field`` to avoid shadowing ``dataclasses.field``.
            schema_field = self._get_field(key)
            new_key = (
                schema_field.name if schema_field.data_key is not None else key
            )
            new_data[new_key] = value

        return new_data