From a65f3272f002c7663c368aa4708ca706547e3bdb Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Thu, 20 Jun 2024 07:12:24 -0500 Subject: [PATCH] Fix union validation logic when `extra='allow'` (#1334) --- src/validators/model.rs | 8 ++--- src/validators/model_fields.rs | 3 ++ src/validators/validation_state.rs | 4 +++ tests/validators/test_union.py | 53 ++++++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 6 deletions(-) diff --git a/src/validators/model.rs b/src/validators/model.rs index fc4ca75d3..2c0cef6fd 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -204,7 +204,6 @@ impl Validator for ModelValidator { for field_name in validated_fields_set { fields_set.add(field_name)?; } - state.add_fields_set(fields_set.len()); } force_setattr(py, model, intern!(py, DUNDER_DICT), validated_dict.to_object(py))?; @@ -244,11 +243,9 @@ impl ModelValidator { }; force_setattr(py, self_instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?; force_setattr(py, self_instance, intern!(py, ROOT_FIELD), &output)?; - state.add_fields_set(fields_set.len()); } else { let (model_dict, model_extra, fields_set): (Bound, Bound, Bound) = output.extract(py)?; - state.add_fields_set(fields_set.len().unwrap_or(0)); set_model_attrs(self_instance, &model_dict, &model_extra, &fields_set)?; } self.call_post_init(py, self_instance.clone(), input, state.extra()) @@ -287,11 +284,10 @@ impl ModelValidator { }; force_setattr(py, &instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?; force_setattr(py, &instance, intern!(py, ROOT_FIELD), output)?; - state.add_fields_set(fields_set.len()); } else { - let (model_dict, model_extra, val_fields_set) = output.extract(py)?; + let (model_dict, model_extra, val_fields_set): (Bound, Bound, Bound) = + output.extract(py)?; let fields_set = existing_fields_set.unwrap_or(&val_fields_set); - state.add_fields_set(fields_set.len().unwrap_or(0)); set_model_attrs(&instance, &model_dict, &model_extra, fields_set)?; } self.call_post_init(py, instance, input, state.extra()) diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index eda76056a..7ecd1d353 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -150,6 +150,7 @@ impl Validator for ModelFieldsValidator { let mut model_extra_dict_op: Option> = None; let mut errors: Vec = Vec::with_capacity(self.fields.len()); let mut fields_set_vec: Vec> = Vec::with_capacity(self.fields.len()); + let mut fields_set_count: usize = 0; // we only care about which keys have been used if we're iterating over the object for extra after // the first pass @@ -184,6 +185,7 @@ impl Validator for ModelFieldsValidator { Ok(value) => { model_dict.set_item(&field.name_py, value)?; fields_set_vec.push(field.name_py.clone_ref(py)); + fields_set_count += 1; } Err(ValError::Omit) => continue, Err(ValError::LineErrors(line_errors)) => { @@ -327,6 +329,7 @@ impl Validator for ModelFieldsValidator { Err(ValError::LineErrors(errors)) } else { let fields_set = PySet::new_bound(py, &fields_set_vec)?; + state.add_fields_set(fields_set_count); // if we have extra=allow, but we didn't create a dict because we were validating // from attributes, set it now so __pydantic_extra__ is always a dict if extra=allow diff --git a/src/validators/validation_state.rs b/src/validators/validation_state.rs index 92edfbbe9..b125cd316 100644 --- a/src/validators/validation_state.rs +++ b/src/validators/validation_state.rs @@ -18,6 +18,10 @@ pub enum Exactness { pub struct ValidationState<'a, 'py> { pub recursion_guard: &'a mut RecursionState, pub exactness: Option, + // This is used as a tie-breaking mechanism for union validation. + // Note: the count of the fields set is not always equivalent to the length of the + // `model_fields_set` attached to a model. `model_fields_set` includes extra fields + // when extra='allow', whereas this tally does not. pub fields_set_count: Option, // deliberately make Extra readonly extra: Extra<'a, 'py>, diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index f2d10f36b..7f9b4b424 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -1280,3 +1280,56 @@ class ModelB: ) assert isinstance(result, ModelB) assert isinstance(result.b, SubModelW) + + +@pytest.mark.parametrize('extra_behavior', ['forbid', 'ignore', 'allow']) +def test_smart_union_extra_behavior(extra_behavior) -> None: + class Foo: + foo: str = 'foo' + + class Bar: + bar: str = 'bar' + + class Model: + x: Union[Foo, Bar] + + validator = SchemaValidator( + core_schema.model_schema( + Model, + core_schema.model_fields_schema( + fields={ + 'x': core_schema.model_field( + core_schema.union_schema( + [ + core_schema.model_schema( + Foo, + core_schema.model_fields_schema( + fields={ + 'foo': core_schema.model_field( + core_schema.with_default_schema(core_schema.str_schema(), default='foo') + ) + } + ), + extra_behavior=extra_behavior, + ), + core_schema.model_schema( + Bar, + core_schema.model_fields_schema( + fields={ + 'bar': core_schema.model_field( + core_schema.with_default_schema(core_schema.str_schema(), default='bar') + ) + } + ), + extra_behavior=extra_behavior, + ), + ] + ) + ) + } + ), + ) + ) + + assert isinstance(validator.validate_python({'x': {'foo': 'foo'}}).x, Foo) + assert isinstance(validator.validate_python({'x': {'bar': 'bar'}}).x, Bar)