Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pydantic validator #1121

Merged
merged 17 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/concepts/function-modifiers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ pandera support
Hamilton has a pandera plugin for data validation that you can install with ``pip install sf-hamilton[pandera]``. Then, you can pass a pandera schema (for DataFrame or Series) to ``@check_output(schema=...)``.


pydantic support
~~~~~~~~~~~~~~~~

Hamilton also supports data validation of pydantic models, which can be enabled with ``pip install sf-hamilton[pydantic]``. With pydantic installed, you can pass any subclass of the pydantic base model to ``@check_output(model=...)``. Pydantic validation is performed in strict mode, meaning that raw values will not be coerced to the model's types. For more information on strict mode see the `pydantic docs <https://docs.pydantic.dev/latest/concepts/strict_mode/>`_.


Split node output into *n* nodes
--------------------------------

Expand Down
10 changes: 7 additions & 3 deletions docs/reference/decorators/check_output.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ Note that you can also specify custom decorators using the ``@check_output_custo
See `data_quality <https://github.com/dagworks-inc/hamilton/blob/main/data\_quality.md>`_ for more information on
available validators and how to build custom ones.

Note we also have a plugin that allows you to use pandera. There are two ways to access it:
1. `@check_output(schema=pandera_schema)`
2. `@h_pandera.check_output()` on a function that declares a typed pandera dataframe as an output
Note we also have a plugins that allow for validation with the pandera and pydantic libraries. There are two ways to access these:

1. ``@check_output(schema=pandera_schema)`` or ``@check_output(model=pydantic_model)``
2. ``@h_pandera.check_output()`` or ``@h_pydantic.check_output()`` on the function that declares either a typed dataframe or a pydantic model.

----

Expand All @@ -43,3 +44,6 @@ Note we also have a plugin that allows you to use pandera. There are two ways to

.. autoclass:: hamilton.plugins.h_pandera.check_output
:special-members: __init__

.. autoclass:: hamilton.plugins.h_pydantic.check_output
:special-members: __init__
17 changes: 17 additions & 0 deletions hamilton/data_quality/default_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,23 @@ def _append_pandera_to_default_validators():
_append_pandera_to_default_validators()


def _append_pydantic_to_default_validators():
"""Utility method to append pydantic validators as needed"""
try:
import pydantic # noqa: F401
except ModuleNotFoundError:
logger.debug(
"Cannot import pydantic from pydantic_validators. Run pip install sf-hamilton[pydantic] if needed."
)
return
from hamilton.data_quality import pydantic_validators

AVAILABLE_DEFAULT_VALIDATORS.extend(pydantic_validators.PYDANTIC_VALIDATORS)


_append_pydantic_to_default_validators()


def resolve_default_validators(
output_type: Type[Type],
importance: str,
Expand Down
60 changes: 60 additions & 0 deletions hamilton/data_quality/pydantic_validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import Any, Type

from pydantic import BaseModel, TypeAdapter, ValidationError

from hamilton.data_quality import base
from hamilton.htypes import custom_subclass_check


class PydanticModelValidator(base.BaseDefaultValidator):
"""Pydantic model compatibility validator

Note that this validator uses pydantic's strict mode, which does not allow for
cswartzvi marked this conversation as resolved.
Show resolved Hide resolved
coercion of data. This means that if an object does not exactly match the reference
type, it will fail validation, regardless of whether it could be coerced into the
correct type.

:param model: Pydantic model to validate against
:param importance: Importance of the validator, possible values "warn" and "fail"
:param arbitrary_types_allowed: Whether arbitrary types are allowed in the model
"""

def __init__(self, model: Type[BaseModel], importance: str):
super(PydanticModelValidator, self).__init__(importance)
self.model = model
self._model_adapter = TypeAdapter(model)

@classmethod
def applies_to(cls, datatype: Type[Type]) -> bool:
# In addition to checking for a subclass of BaseModel, we also check for dict
# as this is the standard 'de-serialized' format of pydantic models in python
return custom_subclass_check(datatype, BaseModel) or custom_subclass_check(datatype, dict)

def description(self) -> str:
return "Validates that the returned object is compatible with the specified pydantic model"

def validate(self, data: Any) -> base.ValidationResult:
try:
# Currently, validate can not alter the output data, so we must use
# strict=True. The downside to this is that data that could be coerced
# into the correct type will fail validation.
self._model_adapter.validate_python(data, strict=True)
elijahbenizzy marked this conversation as resolved.
Show resolved Hide resolved
except ValidationError as e:
return base.ValidationResult(
passes=False, message=str(e), diagnostics={"model_errors": e.errors()}
)
return base.ValidationResult(
passes=True,
message=f"Data passes pydantic check for model {str(self.model)}",
)

@classmethod
def arg(cls) -> str:
return "model"

@classmethod
def name(cls) -> str:
return "pydantic_validator"


PYDANTIC_VALIDATORS = [PydanticModelValidator]
92 changes: 92 additions & 0 deletions hamilton/plugins/h_pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import List

from pydantic import BaseModel

from hamilton import node
from hamilton.data_quality import base as dq_base
from hamilton.function_modifiers import InvalidDecoratorException
from hamilton.function_modifiers import base as fm_base
from hamilton.function_modifiers import check_output as base_check_output
from hamilton.function_modifiers.validation import BaseDataValidationDecorator
from hamilton.htypes import custom_subclass_check


class check_output(BaseDataValidationDecorator):
cswartzvi marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
importance: str = dq_base.DataValidationLevel.WARN.value,
target: fm_base.TargetType = None,
):
"""Specific output-checker for pydantic models. This decorator utilizes the output type of
the function, which can be any subclass of pydantic.BaseModel. The function output must
be declared with a type hint.

:param model: The pydantic model to use for validation. If this is not provided, then the output type of the function is used.
:param importance: Importance level (either "warn" or "fail") -- see documentation for check_output for more details.
:param target: The target of the decorator -- see documentation for check_output for more details.

Here is an example of how to use this decorator with a function that returns a pydantic model:

.. code-block:: python

from pydantic import BaseModel
from hamilton.plugins import h_pydantic

class MyModel(BaseModel):
a: int
b: float
c: str

@h_pydantic.check_output()
def foo() -> MyModel:
return MyModel(a=1, b=2.0, c="hello")

Alternatively, you can return a dictionary from the function (type checkers will probably
complain about this):

.. code-block:: python

from pydantic import BaseModel
from hamilton.plugins import h_pydantic

class MyModel(BaseModel):
a: int
b: float
c: str

@h_pydantic.check_output()
def foo() -> MyModel:
return {"a": 1, "b": 2.0, "c": "hello"}

Note, that because we do not (yet) support modification of the output, the validation is
performed in strict mode, meaning that no data coercion is performed. For example, the
following function will *fail* validation:

.. code-block:: python

from pydantic import BaseModel
from hamilton.plugins import h_pydantic

class MyModel(BaseModel):
a: int # Defined as an int

@h_pydantic.check_output() # This will fail validation!
def foo() -> MyModel:
return MyModel(a="1") # Assigned as a string

For more information about strict mode see the pydantic docs: https://docs.pydantic.dev/latest/concepts/strict_mode/

"""
super(check_output, self).__init__(target)
self.importance = importance
self.target = target

def get_validators(self, node_to_validate: node.Node) -> List[dq_base.DataValidator]:
output_type = node_to_validate.type
if not custom_subclass_check(output_type, BaseModel):
raise InvalidDecoratorException(
f"Output of function {node_to_validate.name} must be a Pydantic model"
)
return base_check_output(
importance=self.importance, model=output_type, target_=self.target
).get_validators(node_to_validate)
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ docs = [
"pillow",
"polars",
"pyarrow >= 1.0.0",
"pydantic >=2.0",
"pyspark",
"openlineage-python",
"PyYAML",
Expand All @@ -99,6 +100,7 @@ packaging = [
"build",
]
pandera = ["pandera"]
pydantic = ["pydantic>=2.0"]
pyspark = [
# we have to run these dependencies because Spark does not check to ensure the right target was called
"pyspark[pandas_on_spark,sql]"
Expand Down Expand Up @@ -129,6 +131,7 @@ test = [
"plotly",
"polars",
"pyarrow",
"pydantic >=2.0",
"pyreadstat", # for SPSS data loader
"pytest",
"pytest-asyncio",
Expand All @@ -144,10 +147,7 @@ test = [
]
tqdm = ["tqdm"]
ui = ["sf-hamilton-ui"]
vaex = [
"pydantic<2.0", # because of https://github.com/vaexio/vaex/issues/2384
"vaex"
]
vaex = ["vaex"]
visualization = ["graphviz", "networkx"]

[project.entry-points.console_scripts]
Expand Down
1 change: 1 addition & 0 deletions tests/integrations/pydantic/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Additional requirements on top of hamilton...pydantic
Loading