Skip to content

Commit

Permalink
feat: Serialize computed field without a model
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelhly committed Oct 15, 2023
1 parent b51105a commit 0929f62
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 5 deletions.
31 changes: 27 additions & 4 deletions src/serializers/computed_fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ impl ComputedFields {
// Do not serialize computed fields
return Ok(());
}
for computed_fields in &self.0 {
computed_fields.to_python(model, output_dict, filter, include, exclude, extra)?;
for computed_field in &self.0 {
computed_field.to_python(model, output_dict, filter, include, exclude, extra)?;
}

Ok(())
}

Expand Down Expand Up @@ -102,6 +103,7 @@ struct ComputedField {
serializer: CombinedSerializer,
alias: String,
alias_py: Py<PyString>,
has_ser_func: bool,
}

impl ComputedField {
Expand All @@ -123,6 +125,7 @@ impl ComputedField {
serializer,
alias: alias_py.extract()?,
alias_py: alias_py.into_py(py),
has_ser_func: has_ser_function(return_schema),
})
}

Expand All @@ -139,7 +142,11 @@ impl ComputedField {
let property_name_py = self.property_name_py.as_ref(py);

if let Some((next_include, next_exclude)) = filter.key_filter(property_name_py, include, exclude)? {
let next_value = model.getattr(property_name_py)?;
let next_value = if self.has_ser_func {
model
} else {
model.getattr(property_name_py)?
};

let value = self
.serializer
Expand Down Expand Up @@ -179,7 +186,11 @@ impl<'py> Serialize for ComputedFieldSerializer<'py> {
fn serialize<S: serde::ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let py = self.model.py();
let property_name_py = self.computed_field.property_name_py.as_ref(py);
let next_value = self.model.getattr(property_name_py).map_err(py_err_se_err)?;
let next_value = if self.computed_field.has_ser_func {
self.model
} else {
self.model.getattr(property_name_py).map_err(py_err_se_err)?
};
let s = PydanticSerializer::new(
next_value,
&self.computed_field.serializer,
Expand All @@ -190,3 +201,15 @@ impl<'py> Serialize for ComputedFieldSerializer<'py> {
s.serialize(serializer)
}
}

fn has_ser_function(schema: &PyDict) -> bool {
let py = schema.py();
let ser_schema = schema
.get_as::<&PyDict>(intern!(py, "serialization"))
.unwrap_or_default();

match ser_schema {
Some(s) => s.contains(intern!(py, "function")).unwrap_or_default(),
None => false,
}
}
71 changes: 70 additions & 1 deletion tests/serializers/test_typed_dict.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, Dict
from typing import Any, Dict, List

import pytest
from dirty_equals import IsStrictDict
Expand Down Expand Up @@ -333,3 +333,72 @@ def test_extra_custom_serializer():
m = {'extra': 'extra'}

assert s.to_python(m) == {'extra': 'extra bam!'}


def test_computed_fields_with_plain_serializer_function():
def ser_x(v: dict):
two = v['0'] + v['1'] + 1
return two

schema = core_schema.typed_dict_schema(
{
'0': core_schema.typed_dict_field(core_schema.int_schema()),
'1': core_schema.typed_dict_field(core_schema.int_schema()),
},
computed_fields=[
core_schema.computed_field(
'2', core_schema.int_schema(serialization=core_schema.plain_serializer_function_ser_schema(ser_x))
)
],
)
s = SchemaSerializer(schema)
value = {'0': 0, '1': 1}
assert s.to_python(value) == {'0': 0, '1': 1, '2': 2}

def ser_foo(_v: dict):
return 'bar'

schema = core_schema.typed_dict_schema(
{},
computed_fields=[
core_schema.computed_field(
'foo', core_schema.str_schema(serialization=core_schema.plain_serializer_function_ser_schema(ser_foo))
)
],
)
s = SchemaSerializer(schema)
assert s.to_python({}) == {'foo': 'bar'}


def test_computed_fields_with_warpped_serializer_function():
def ser_to_upper(string_arr: List[str]) -> List[str]:
return [s.upper() for s in string_arr]

def ser_columns(v: dict, serializer: core_schema.SerializerFunctionWrapHandler, _) -> str:
column_keys = serializer([key for key in v.keys()])
return column_keys

schema = core_schema.typed_dict_schema(
{
'one': core_schema.typed_dict_field(core_schema.int_schema()),
'two': core_schema.typed_dict_field(core_schema.int_schema()),
'three': core_schema.typed_dict_field(core_schema.int_schema()),
},
computed_fields=[
core_schema.computed_field(
'columns',
core_schema.int_schema(
serialization=core_schema.wrap_serializer_function_ser_schema(
ser_columns,
info_arg=True,
schema=core_schema.list_schema(
serialization=core_schema.plain_serializer_function_ser_schema(ser_to_upper)
),
)
),
)
],
)
s = SchemaSerializer(schema)
value = {'one': 1, 'two': 2, 'three': 3}
assert s.to_python(value) == {'one': 1, 'two': 2, 'three': 3, 'columns': ['ONE', 'TWO', 'THREE']}

0 comments on commit 0929f62

Please sign in to comment.