Skip to content

Commit

Permalink
feat: add 'millisecond' option to ser_json_timedelta config parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
ollz272 committed Aug 30, 2024
1 parent d93e6b1 commit 0107915
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class CoreConfig(TypedDict, total=False):
# fields related to float fields only
allow_inf_nan: bool # default: True
# the config options are used to customise serialization to JSON
ser_json_timedelta: Literal['iso8601', 'float'] # default: 'iso8601'
ser_json_timedelta: Literal['iso8601', 'float', 'millisecond'] # default: 'iso8601'
ser_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8'
ser_json_inf_nan: Literal['null', 'constants', 'strings'] # default: 'null'
val_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8'
Expand Down
24 changes: 23 additions & 1 deletion src/serializers/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::str::{from_utf8, FromStr, Utf8Error};
use base64::Engine;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDelta, PyDict, PyString};
use pyo3::types::{PyDelta, PyDict, PyFloat, PyString};

use serde::ser::Error;

Expand Down Expand Up @@ -89,6 +89,7 @@ serialization_mode! {
"ser_json_timedelta",
Iso8601 => "iso8601",
Float => "float",
Millisecond => "millisecond"
}

serialization_mode! {
Expand Down Expand Up @@ -125,6 +126,14 @@ impl TimedeltaMode {
let seconds = Self::total_seconds(&py_timedelta)?;
Ok(seconds.into_py(py))
}
Self::Millisecond => {
// convert to int via a py timedelta not duration since we know this this case the input would have
// been a py timedelta
let py_timedelta = either_delta.try_into_py(py)?;
let seconds: f64 = Self::total_seconds(&py_timedelta)?.extract()?;
let object: Bound<PyFloat> = PyFloat::new_bound(py, seconds * 1000.0);
Ok(object.into_py(py))
}
}
}

Expand All @@ -139,6 +148,12 @@ impl TimedeltaMode {
let seconds: f64 = Self::total_seconds(&py_timedelta)?.extract()?;
Ok(seconds.to_string().into())
}
Self::Millisecond => {
let py_timedelta = either_delta.try_into_py(py)?;
let seconds: f64 = Self::total_seconds(&py_timedelta)?.extract()?;
let milliseconds: f64 = seconds * 1000.0;
Ok(milliseconds.to_string().into())
}
}
}

Expand All @@ -159,6 +174,13 @@ impl TimedeltaMode {
let seconds: f64 = seconds.extract().map_err(py_err_se_err)?;
serializer.serialize_f64(seconds)
}
Self::Millisecond => {
let py_timedelta = either_delta.try_into_py(py).map_err(py_err_se_err)?;
let seconds = Self::total_seconds(&py_timedelta).map_err(py_err_se_err)?;
let seconds: f64 = seconds.extract().map_err(py_err_se_err)?;
let milliseconds: f64 = seconds * 1000.0;
serializer.serialize_f64(milliseconds)
}
}
}
}
Expand Down
24 changes: 24 additions & 0 deletions tests/serializers/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,30 @@ def test_any_config_timedelta_float_faction():
assert s.to_json({one_half_s: 'foo'}) == b'{"1.5":"foo"}'


def test_any_config_timedelta_millisecond():
s = SchemaSerializer(core_schema.any_schema(), config={'ser_json_timedelta': 'millisecond'})
h2 = timedelta(hours=2)
assert s.to_python(h2) == h2
assert s.to_python(h2, mode='json') == 7200000.0
assert s.to_json(h2) == b'7200000.0'

assert s.to_python({h2: 'foo'}) == {h2: 'foo'}
assert s.to_python({h2: 'foo'}, mode='json') == {'7200000': 'foo'}
assert s.to_json({h2: 'foo'}) == b'{"7200000":"foo"}'


def test_any_config_timedelta_millisecond_fraction():
s = SchemaSerializer(core_schema.any_schema(), config={'ser_json_timedelta': 'millisecond'})
h2 = timedelta(seconds=1.5)
assert s.to_python(h2) == h2
assert s.to_python(h2, mode='json') == 1500.0
assert s.to_json(h2) == b'1500.0'

assert s.to_python({h2: 'foo'}) == {h2: 'foo'}
assert s.to_python({h2: 'foo'}, mode='json') == {'1500': 'foo'}
assert s.to_json({h2: 'foo'}) == b'{"1500":"foo"}'


def test_recursion(any_serializer):
v = [1, 2]
v.append(v)
Expand Down

0 comments on commit 0107915

Please sign in to comment.