diff --git a/hamilton/experimental/h_databackends.py b/hamilton/experimental/h_databackends.py index 6b287ea66..481d4dcaa 100644 --- a/hamilton/experimental/h_databackends.py +++ b/hamilton/experimental/h_databackends.py @@ -28,7 +28,7 @@ def _(df: h_databackends.AbstractIbisDataFrame) -> pyarrow.Schema: import importlib import inspect -from typing import Tuple, Type, Union +from typing import Tuple from hamilton.experimental.databackend import AbstractBackend @@ -128,7 +128,7 @@ class AbstractModinDataFrame(AbstractBackend): _backends = [("modin.pandas", "DataFrame")] -def register_backends() -> Tuple[Type, Type]: +def register_backends() -> Tuple[Tuple[type], Tuple[type]]: """Register databackends defined in this module that include `DataFrame` and `Column` in their class name """ @@ -143,8 +143,8 @@ def register_backends() -> Tuple[Type, Type]: abstract_column_types.add(cls) # Union[tuple()] creates a Union type object - DATAFRAME_TYPES = Union[tuple(abstract_dataframe_types)] - COLUMN_TYPES = Union[tuple(abstract_column_types)] + DATAFRAME_TYPES = tuple(abstract_dataframe_types) + COLUMN_TYPES = tuple(abstract_column_types) return DATAFRAME_TYPES, COLUMN_TYPES diff --git a/hamilton/plugins/h_schema.py b/hamilton/plugins/h_schema.py index 3d44e375e..4b748bd43 100644 --- a/hamilton/plugins/h_schema.py +++ b/hamilton/plugins/h_schema.py @@ -278,7 +278,9 @@ def _(df: h_databackends.AbstractIbisDataFrame, **kwargs) -> pyarrow.Schema: # ongoing polars discussion: https://github.com/pola-rs/polars/issues/15600 -def get_dataframe_schema(df: h_databackends.DATAFRAME_TYPES, node: HamiltonNode) -> pyarrow.Schema: +def get_dataframe_schema( + df: Union[h_databackends.DATAFRAME_TYPES], node: HamiltonNode +) -> pyarrow.Schema: """Get pyarrow schema of a node result and store node metadata on the pyarrow schema.""" schema = _get_arrow_schema(df) metadata = dict( diff --git a/tests/experimental/__init__.py b/tests/experimental/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/experimental/test_h_databackends.py b/tests/experimental/test_h_databackends.py new file mode 100644 index 000000000..15015ca16 --- /dev/null +++ b/tests/experimental/test_h_databackends.py @@ -0,0 +1,43 @@ +import pandas as pd + +from hamilton.experimental import h_databackends + + +def test_isinstance_dataframe(): + value = pd.DataFrame() + assert isinstance(value, h_databackends.DATAFRAME_TYPES) + + +def test_issubclass_dataframe(): + class_ = pd.DataFrame + assert issubclass(class_, h_databackends.DATAFRAME_TYPES) + + +def test_not_isinstance_dataframe(): + value = 6 + assert not isinstance(value, h_databackends.DATAFRAME_TYPES) + + +def test_not_issubclass_dataframe(): + class_ = int + assert not issubclass(class_, h_databackends.DATAFRAME_TYPES) + + +def test_isinstance_column(): + value = pd.Series() + assert isinstance(value, h_databackends.COLUMN_TYPES) + + +def test_issubclass_column(): + class_ = pd.Series + assert issubclass(class_, h_databackends.COLUMN_TYPES) + + +def test_not_isinstance_column(): + value = 6 + assert not isinstance(value, h_databackends.COLUMN_TYPES) + + +def test_not_issubclass_column(): + class_ = int + assert not issubclass(class_, h_databackends.COLUMN_TYPES) diff --git a/tests/plugins/test_h_schema.py b/tests/plugins/test_h_schema.py index b265316d2..512902b0f 100644 --- a/tests/plugins/test_h_schema.py +++ b/tests/plugins/test_h_schema.py @@ -1,9 +1,11 @@ import json import pathlib +import pandas as pd import pyarrow import pytest +from hamilton import graph_types from hamilton.plugins import h_schema @@ -140,7 +142,6 @@ def test_schema_edited_schema_metadata(schema1: pyarrow.Schema, metadata1: dict, assert schema_diff[h_schema.SCHEMA_METADATA_FIELD].value["key"].diff == h_schema.Diff.UNEQUAL human_readable_diff = h_schema.human_readable_diff(schema_diff) - print(human_readable_diff) assert human_readable_diff == { h_schema.SCHEMA_METADATA_FIELD: {"key": {"cur": "value1", "ref": "value2"}} @@ -227,3 +228,81 @@ def test_save_schema_to_disk(schema1: pyarrow.Schema, tmp_path: pathlib.Path): h_schema.save_schema(path=schema_path, schema=schema1) loaded_schema = pyarrow.ipc.read_schema(schema_path) assert schema1.equals(loaded_schema) + + +def test_get_dataframe_schema(): + def foo(x: pd.DataFrame) -> pd.DataFrame: + """doc""" + return x + + version = graph_types.hash_source_code(foo, strip=True) + node = graph_types.HamiltonNode( + name=foo.__name__, + type=pd.DataFrame, + documentation=foo.__doc__, + tags={}, + is_external_input=False, + originating_functions=(foo,), + required_dependencies=set(), + optional_dependencies=set(), + ) + df = pd.DataFrame({"a": [0, 1], "b": [True, False]}) + + expected_schema = pyarrow.schema( + [ + ("a", "int64"), + ("b", "bool"), + ] + ) + expected_metadata = { + b"name": foo.__name__.encode(), + b"documentation": foo.__doc__.encode(), + b"version": version.encode(), + } + + schema = h_schema.get_dataframe_schema(df, node) + + assert schema.equals(expected_schema.with_metadata(expected_metadata), check_metadata=True) + + +def test_schema_validator_after_node_execution(tmp_path): + def foo(x: pd.DataFrame) -> pd.DataFrame: + """doc""" + return x + + version = graph_types.hash_source_code(foo, strip=True) + node = graph_types.HamiltonNode( + name=foo.__name__, + type=pd.DataFrame, + documentation=foo.__doc__, + tags={}, + is_external_input=False, + originating_functions=(foo,), + required_dependencies=set(), + optional_dependencies=set(), + ) + h_graph = graph_types.HamiltonGraph([node]) + df = pd.DataFrame({"a": [0, 1], "b": [True, False]}) + + expected_schema = pyarrow.schema( + [ + ("a", "int64"), + ("b", "bool"), + ] + ) + expected_metadata = { + b"name": foo.__name__.encode(), + b"documentation": foo.__doc__.encode(), + b"version": version.encode(), + } + + # set the HamiltonGraph on the state of the adaper + adapter = h_schema.SchemaValidator(schema_dir=tmp_path) + adapter.h_graph = h_graph + + adapter.run_after_node_execution(node_name=foo.__name__, result=df) + + tracked_schema = adapter.schemas[foo.__name__] + assert tracked_schema.equals( + expected_schema.with_metadata(expected_metadata), check_metadata=True + )