Improvement in generic schema generation #1371

Open
adriangb opened this issue Jul 17, 2024 · 1 comment
@adriangb
Member

Currently, if you have a generic model and later substitute its type variables, we regenerate the entire schema by walking all of the types. This makes schema generation very expensive.

I propose that instead we wrap types as follows:

generic_model = {
    "type": "definitions",
    "schema": {
        "type": "model",
        "fields": {
            "foo": {
                "type": "definitions-ref",
                "ref": "T::some-ref",
            }
        }
    },
    "definitions": [
        {
            "type": "any",
            "ref": "T::some-ref",
        }
    ]
}

concrete_model = {
    "type": "definitions",
    "schema": generic_model,
    "definitions": [
        {
            "type": "int",
            "ref": "T::some-ref",
        }
    ]
}

This has the potential to greatly improve schema generation performance for generics.
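To make the lookup rule concrete, here is a minimal pure-Python sketch of which definition a `definitions-ref` would end up pointing at under this wrapping. It is not pydantic-core's implementation: `resolve` is a hypothetical helper, and it assumes that a definition on an outer wrapper takes precedence over an inner definition with the same ref. The walk only illustrates the lookup rule; the point of the proposal is that the generic schema object is reused as-is.

```python
def resolve(schema, overrides=None):
    """Return `schema` with each definitions-ref replaced by its definition.

    Toy model only: outer definitions win over inner ones with the same ref.
    """
    overrides = dict(overrides or {})
    kind = schema["type"]
    if kind == "definitions":
        # Inner definitions only fill in refs that no outer wrapper defined.
        scope = {d["ref"]: d for d in schema["definitions"]}
        scope.update(overrides)
        return resolve(schema["schema"], scope)
    if kind == "definitions-ref":
        return overrides[schema["ref"]]
    if kind == "model":
        fields = {k: resolve(v, overrides) for k, v in schema["fields"].items()}
        return {"type": "model", "fields": fields}
    return schema


generic_model = {
    "type": "definitions",
    "schema": {
        "type": "model",
        "fields": {"foo": {"type": "definitions-ref", "ref": "T::some-ref"}},
    },
    "definitions": [{"type": "any", "ref": "T::some-ref"}],
}

# Substituting T only adds a wrapper; generic_model itself is not rebuilt.
concrete_model = {
    "type": "definitions",
    "schema": generic_model,
    "definitions": [{"type": "int", "ref": "T::some-ref"}],
}

generic_foo = resolve(generic_model)["fields"]["foo"]["type"]
concrete_foo = resolve(concrete_model)["fields"]["foo"]["type"]
print(generic_foo, concrete_foo)
```

Under that assumption, `foo` resolves to `any` in the generic model and to `int` in the concrete one, while `concrete_model["schema"]` is still the very same `generic_model` object.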

@adriangb
Member Author

adriangb commented Aug 8, 2024

@sydney-runkle created https://github.com/pydantic/pydantic-core/tree/typevar-schema-changes today, which has the changes needed for this to work in pydantic-core.

I spoke with @dmontagu and he brought up two important points:

  1. Performance vs. analyzing the schema and simplifying it after replacing type vars.
  2. Issues with nested generics.

For (1) I ran this benchmark:

from typing_extensions import TypeVar

from pydantic_core import core_schema as cs, SchemaValidator

# schema for a generic list of T
T = TypeVar('T')
X = list[T]
generic_schema = cs.definitions_schema(
    cs.list_schema(cs.definition_reference_schema('T')),
    [
        cs.any_schema(ref='T')
    ]
)

# substitute T with float
Y = X[float]

# the result of this proposal
concrete_schema = cs.definitions_schema(
    generic_schema,
    [
        cs.float_schema(ref='T')
    ]
)
v2 = SchemaValidator(concrete_schema).validate_python

# an "ideal" simplified schema that requires re-writing the entire schema
simplified_concrete_schemas = cs.list_schema(cs.float_schema())
v3 = SchemaValidator(simplified_concrete_schemas).validate_python

%timeit v2([1.0, 2.0, 3.0])
%timeit v3([1.0, 2.0, 3.0])

Which gave me 167 ns ± 2.85 ns and 188 ns ± 3.68 ns, so a ~12% slowdown. Maybe that's okay, maybe it's not. Maybe we can make it less of a performance drag to use definitions.
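For reference, the slowdown figure can be reproduced from the two timings. This is only the arithmetic, and it assumes the 167 ns figure is the baseline and the 188 ns figure is the slower variant; the variable names are illustrative.

```python
baseline_ns = 167.0   # assumed: the faster of the two reported timings
candidate_ns = 188.0  # assumed: the slower of the two reported timings

# Relative slowdown of the slower timing with respect to the faster one.
slowdown = candidate_ns / baseline_ns - 1
print(f"{slowdown:.1%}")
```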

Here's the example for (2) that is currently broken with this proposal / branch:

from __future__ import annotations

from typing_extensions import TypeVar

from pydantic_core import core_schema as cs, SchemaValidator


class MyGeneric[T]:
    int_field: MyGeneric[int] | int
    str_field: MyGeneric[str] | str
    generic_field: MyGeneric[T] | T

X = MyGeneric[float]

x_schema = cs.definitions_schema(
    cs.definitions_schema(
        cs.definition_reference_schema(schema_ref='MyGeneric'),
        [
            cs.any_schema(ref='T'),
            cs.model_schema(
                MyGeneric,
                cs.model_fields_schema(
                    {
                        'int_field': cs.model_field(
                            cs.union_schema(
                                [
                                    cs.int_schema(),
                                    cs.definitions_schema(
                                        cs.definition_reference_schema(schema_ref='MyGeneric'),
                                        [
                                            cs.int_schema(ref='T'),
                                        ]
                                    )
                                ]
                            )
                        ),
                        'str_field': cs.model_field(
                            cs.union_schema(
                                [
                                    cs.str_schema(),
                                    cs.definitions_schema(
                                        cs.definition_reference_schema(schema_ref='MyGeneric'),
                                        [
                                            cs.str_schema(ref='T'),
                                        ]
                                    )
                                ]
                            )
                        ),
                        'generic_field': cs.model_field(
                            cs.union_schema(
                                [
                                    cs.definition_reference_schema(schema_ref='T'),
                                    cs.definitions_schema(
                                        cs.definition_reference_schema(schema_ref='MyGeneric'),
                                        [
                                            cs.definition_reference_schema(schema_ref='T', ref='T'),
                                        ]
                                    ),
                                ]
                            )
                        )
                    }
                ),
                ref='MyGeneric',
            ),
        ]
    ),
    [
        cs.float_schema(ref='T')
    ],
)

x_validator = SchemaValidator(x_schema)


outer = {
    'int_field': {'int_field': 1, 'str_field': 'a', 'generic_field': 2},
    'str_field': {'int_field': 1, 'str_field': 'a', 'generic_field': '2'},
    'generic_field': 3,
}
x: X = x_validator.validate_python(outer)
assert isinstance(x.generic_field, float)
assert isinstance(x.int_field, MyGeneric), x.int_field
assert isinstance(x.int_field.generic_field, int), x.int_field.generic_field  # fails, is actually float
assert isinstance(x.str_field, MyGeneric), x.str_field
assert isinstance(x.str_field.generic_field, str), x.str_field.generic_field  # fails, is actually float
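One way to read the failing assertions: if the outermost definition for a ref always wins (which is what the substitution in the proposal relies on), then the `T` bound to `float` at the top level also shadows the `T` bound inside `MyGeneric[int]` and `MyGeneric[str]`. The sketch below models only that precedence rule with `collections.ChainMap`. It is a toy model under that assumption, not a description of what the branch actually does, and `lookup` is a hypothetical helper.

```python
from collections import ChainMap

# Definitions scopes from the example above.
outermost = {"T": "float"}       # X = MyGeneric[float]
generic = {"T": "any"}           # the unparametrized MyGeneric
int_field_scope = {"T": "int"}   # MyGeneric[int] inside int_field
str_field_scope = {"T": "str"}   # MyGeneric[str] inside str_field


def lookup(ref, *scopes_outermost_first):
    # ChainMap searches its maps left to right, so the outermost scope wins.
    return ChainMap(*scopes_outermost_first)[ref]


top_level_T = lookup("T", outermost, generic)
nested_int_T = lookup("T", outermost, generic, int_field_scope)
nested_str_T = lookup("T", outermost, generic, str_field_scope)
print(top_level_T, nested_int_T, nested_str_T)
```

All three lookups come back as `float` in this model, which matches the "is actually float" results above. A fix would need nested substitutions to be resolved relative to their own scope rather than the outermost one.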
