feat: add 'millisecond' option to ser_json_timedelta config parameter #1427

Merged
merged 23 commits into from
Sep 18, 2024
2 changes: 1 addition & 1 deletion python/pydantic_core/core_schema.py
@@ -105,7 +105,7 @@ class CoreConfig(TypedDict, total=False):
# fields related to float fields only
allow_inf_nan: bool # default: True
# the config options are used to customise serialization to JSON
-ser_json_timedelta: Literal['iso8601', 'float'] # default: 'iso8601'
+ser_json_timedelta: Literal['iso8601', 'float', 'millisecond'] # default: 'iso8601'
ser_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8'
ser_json_inf_nan: Literal['null', 'constants', 'strings'] # default: 'null'
val_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8'
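
To make the new option concrete, here is a minimal standard-library sketch of what the 'float' and 'millisecond' modes are expected to produce for a `timedelta`, based on the tests in this PR. The helper name `serialize_timedelta` is hypothetical and is not part of pydantic_core; the 'iso8601' mode is left out on purpose.

```python
from datetime import timedelta
from typing import Literal


def serialize_timedelta(td: timedelta, mode: Literal['float', 'millisecond']) -> float:
    """Hypothetical helper mirroring the numeric ser_json_timedelta modes."""
    seconds = td.total_seconds()
    if mode == 'float':
        # 'float' mode: total seconds as a float
        return seconds
    if mode == 'millisecond':
        # 'millisecond' mode: total seconds scaled by 1000
        return seconds * 1000.0
    raise ValueError(f'unsupported mode: {mode!r}')


print(serialize_timedelta(timedelta(hours=2), 'float'))
print(serialize_timedelta(timedelta(hours=2), 'millisecond'))
```

In real use the mode would be passed through the serializer config, as in `config={'ser_json_timedelta': 'millisecond'}` in the tests of this PR.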
24 changes: 23 additions & 1 deletion src/serializers/config.rs
@@ -4,7 +4,7 @@ use std::str::{from_utf8, FromStr, Utf8Error};
use base64::Engine;
use pyo3::intern;
use pyo3::prelude::*;
-use pyo3::types::{PyDelta, PyDict, PyString};
+use pyo3::types::{PyDelta, PyDict, PyFloat, PyString};

use serde::ser::Error;

@@ -89,6 +89,7 @@ serialization_mode! {
"ser_json_timedelta",
Iso8601 => "iso8601",
Float => "float",
Millisecond => "millisecond"
}

serialization_mode! {
@@ -125,6 +126,14 @@ impl TimedeltaMode {
let seconds = Self::total_seconds(&py_timedelta)?;
Ok(seconds.into_py(py))
}
Self::Millisecond => {
// convert to int via a py timedelta not duration since we know that in this case the input would have
// been a py timedelta
let py_timedelta = either_delta.try_into_py(py)?;
let seconds: f64 = Self::total_seconds(&py_timedelta)?.extract()?;
Contributor Author:

There might be a better way to do this which is eluding me. Maybe we could do the multiplication in Python? 🤷🏻‍♂️

Member:

Seems reasonable enough - multiplication here should be faster.

Can we pull out some of the shared logic into a function (both for Float and for Millisecond) and reuse that across the various branches?

Contributor Author (@ollz272, Sep 11, 2024):

Hm, maybe, a couple of them use logic like:

let py_timedelta = either_delta.try_into_py(py)?;
let seconds: f64 = Self::total_seconds(&py_timedelta)?.extract()?;

However, the serializer then needs to wrap these in map_err calls, so it can't be reused there.

Maybe a Rust wizard could help with some clever refactoring I'm not seeing 🧙🏻

Contributor (@davidhewitt, Sep 12, 2024):

To do the multiplication in Python would be something like

let object = Self::total_seconds(&py_timedelta)?.mul(1000)?;

... this requires creating a Python integer 1000, which for best performance we might want to consider caching.

But there is another inefficiency here (and in the other cases) which is that the call to try_into_py creates a Python timedelta object which then gets thrown away immediately to convert to float. It will probably be better in all cases to use .to_duration(), which will avoid the temporary Python object in the case of a Duration Rust value.

The best option would be to go further and to add a .total_seconds() method to EitherTimedelta which extracts the fractional seconds from whatever state the EitherTimedelta is currently in (doing the most efficient thing for each case) and then doing the multiplication in Rust.

Contributor Author:

Hmm, yeah, neat suggestion on the EitherTimedelta.

I guess we'd want something like this:

impl<'a> EitherTimedelta<'a> {
    ....

    pub fn total_seconds(&self, py: Python<'a>) -> f64 {
        match self {
            Self::Raw(timedelta) => ...
            Self::PyExact(py_timedelta) => ...
            Self::PySubclass(py_timedelta) => ...
        }
    }
}

And then we have two cases to deal with: Duration and PyDelta. It looks like we have some methods around that would get the py_timedelta into a Duration object, so then it's just Duration we need to deal with. However (maybe I'm reading the docs wrong!), I can't seem to see a method on there that returns the total_seconds as a float?

Probably missed something here!

Contributor Author (@ollz272, Sep 12, 2024):

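    pub fn total_seconds(&self) -> PyResult<f64> {
        match self {
            Self::Raw(timedelta) => ...,
            // call the Python-level total_seconds() method and extract the float
            Self::PyExact(py_timedelta) => py_timedelta.call_method0(intern!(py_timedelta.py(), "total_seconds"))?.extract(),
            Self::PySubclass(py_timedelta) => py_timedelta.call_method0(intern!(py_timedelta.py(), "total_seconds"))?.extract(),
        }
    }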

Something like this? Not 100% sure what to do with the Raw case, but doing what we do currently and extracting from the python object by calling "total_seconds" on it feels like the best way in this case.

Contributor:

I suggest looking at the to_duration method to see how that gets days / seconds / microseconds out of the Python value, as that should have the most efficient implementation already set up to get the data out (and then you can do a bit of arithmetic). In general calling a Python method will be slow-ish, so there might be a faster step for PyExact case.

For Raw, you need to work with the speedate::Duration object, can probably get a timestamp value and microseconds and combine those.
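
As a rough illustration of the arithmetic suggested above, the following Python sketch combines the days / seconds / microseconds components of a `timedelta` directly instead of calling `total_seconds()`. The function names are hypothetical, and this is only an analogue of what a Rust `total_seconds` method on EitherTimedelta might do, not the code that was merged.

```python
from datetime import timedelta

SECONDS_PER_DAY = 86_400
MICROS_PER_SECOND = 1_000_000


def total_seconds_from_parts(td: timedelta) -> float:
    """Combine the normalized components into fractional seconds."""
    return td.days * SECONDS_PER_DAY + td.seconds + td.microseconds / MICROS_PER_SECOND


def total_milliseconds_from_parts(td: timedelta) -> float:
    """Sum everything as integer microseconds, then divide once.

    Doing a single division at the end keeps float rounding to one step.
    """
    total_micros = (td.days * SECONDS_PER_DAY + td.seconds) * MICROS_PER_SECOND + td.microseconds
    return total_micros / 1_000


print(total_seconds_from_parts(timedelta(seconds=1.5)))
print(total_milliseconds_from_parts(timedelta(hours=2)))
```

Negative durations work with the same formula because Python normalizes them so that only `days` is negative, for example `timedelta(seconds=-1.5)` is stored as `days=-1, seconds=86398, microseconds=500000`.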

let object: Bound<PyFloat> = PyFloat::new_bound(py, seconds * 1000.0);
Ok(object.into_py(py))
}
}
}

@@ -139,6 +148,12 @@ impl TimedeltaMode {
let seconds: f64 = Self::total_seconds(&py_timedelta)?.extract()?;
Ok(seconds.to_string().into())
}
Self::Millisecond => {
let py_timedelta = either_delta.try_into_py(py)?;
let seconds: f64 = Self::total_seconds(&py_timedelta)?.extract()?;
let milliseconds: f64 = seconds * 1000.0;
Ok(milliseconds.to_string().into())
}
}
}

@@ -159,6 +174,13 @@ impl TimedeltaMode {
let seconds: f64 = seconds.extract().map_err(py_err_se_err)?;
serializer.serialize_f64(seconds)
}
Self::Millisecond => {
let py_timedelta = either_delta.try_into_py(py).map_err(py_err_se_err)?;
let seconds = Self::total_seconds(&py_timedelta).map_err(py_err_se_err)?;
let seconds: f64 = seconds.extract().map_err(py_err_se_err)?;
let milliseconds: f64 = seconds * 1000.0;
serializer.serialize_f64(milliseconds)
}
}
}
}
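
One detail worth spelling out: when a timedelta is used as a dict key, the branch that calls `milliseconds.to_string()` formats the float with Rust's `Display`, which prints whole values without a trailing `.0`. That is why the tests in this PR expect the key `"7200000"` while the plain value serializes as `7200000.0`. The Python sketch below approximates that formatting for ordinary finite values; both function names are hypothetical, and edge cases such as NaN, infinities, negative zero and very large magnitudes are not modeled.

```python
from datetime import timedelta


def rust_style_float_str(value: float) -> str:
    """Approximate Rust's `f64::to_string()` for ordinary finite values."""
    if value.is_integer():
        # Rust prints 7200000.0 as "7200000"; Python's str() would keep the ".0"
        return str(int(value))
    return repr(value)


def millisecond_key(td: timedelta) -> str:
    """Hypothetical analogue of the JSON dict-key path in 'millisecond' mode."""
    return rust_style_float_str(td.total_seconds() * 1000.0)


print(millisecond_key(timedelta(hours=2)))
print(millisecond_key(timedelta(seconds=1.5)))
```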
24 changes: 24 additions & 0 deletions tests/serializers/test_any.py
@@ -201,6 +201,30 @@ def test_any_config_timedelta_float_faction():
assert s.to_json({one_half_s: 'foo'}) == b'{"1.5":"foo"}'


def test_any_config_timedelta_millisecond():
s = SchemaSerializer(core_schema.any_schema(), config={'ser_json_timedelta': 'millisecond'})
h2 = timedelta(hours=2)
assert s.to_python(h2) == h2
assert s.to_python(h2, mode='json') == 7200000.0
assert s.to_json(h2) == b'7200000.0'

assert s.to_python({h2: 'foo'}) == {h2: 'foo'}
assert s.to_python({h2: 'foo'}, mode='json') == {'7200000': 'foo'}
assert s.to_json({h2: 'foo'}) == b'{"7200000":"foo"}'


def test_any_config_timedelta_millisecond_fraction():
s = SchemaSerializer(core_schema.any_schema(), config={'ser_json_timedelta': 'millisecond'})
h2 = timedelta(seconds=1.5)
assert s.to_python(h2) == h2
assert s.to_python(h2, mode='json') == 1500.0
assert s.to_json(h2) == b'1500.0'

assert s.to_python({h2: 'foo'}) == {h2: 'foo'}
assert s.to_python({h2: 'foo'}, mode='json') == {'1500': 'foo'}
assert s.to_json({h2: 'foo'}) == b'{"1500":"foo"}'


def test_recursion(any_serializer):
v = [1, 2]
v.append(v)