added tests for databackends and schema adapter #973

Merged · 1 commit · Jun 20, 2024

8 changes: 4 additions & 4 deletions hamilton/experimental/h_databackends.py
@@ -28,7 +28,7 @@ def _(df: h_databackends.AbstractIbisDataFrame) -> pyarrow.Schema:

 import importlib
 import inspect
-from typing import Tuple, Type, Union
+from typing import Tuple

 from hamilton.experimental.databackend import AbstractBackend

@@ -128,7 +128,7 @@ class AbstractModinDataFrame(AbstractBackend):
     _backends = [("modin.pandas", "DataFrame")]


-def register_backends() -> Tuple[Type, Type]:
+def register_backends() -> Tuple[Tuple[type], Tuple[type]]:
     """Register databackends defined in this module that
     include `DataFrame` and `Column` in their class name
     """
@@ -143,8 +143,8 @@ def register_backends() -> Tuple[Type, Type]:
             abstract_column_types.add(cls)

     # Union[tuple()] creates a Union type object
-    DATAFRAME_TYPES = Union[tuple(abstract_dataframe_types)]
-    COLUMN_TYPES = Union[tuple(abstract_column_types)]
+    DATAFRAME_TYPES = tuple(abstract_dataframe_types)
+    COLUMN_TYPES = tuple(abstract_column_types)
     return DATAFRAME_TYPES, COLUMN_TYPES


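The switch from Union[tuple(...)] to a plain tuple is what lets the new tests below call isinstance() and issubclass() directly, since both builtins natively accept a tuple of types. A minimal sketch of the pattern, using illustrative class names rather than Hamilton's actual backends:

# Minimal sketch: a plain tuple of abstract types plugs straight into
# isinstance()/issubclass(); the class names here are illustrative.
class AbstractPandasDataFrame: ...
class AbstractPolarsDataFrame: ...

DATAFRAME_TYPES = (AbstractPandasDataFrame, AbstractPolarsDataFrame)

class ConcreteDF(AbstractPandasDataFrame): ...

assert isinstance(ConcreteDF(), DATAFRAME_TYPES)
assert issubclass(ConcreteDF, DATAFRAME_TYPES)
assert not isinstance(6, DATAFRAME_TYPES)
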
4 changes: 3 additions & 1 deletion hamilton/plugins/h_schema.py
@@ -278,7 +278,9 @@ def _(df: h_databackends.AbstractIbisDataFrame, **kwargs) -> pyarrow.Schema:
 # ongoing polars discussion: https://github.com/pola-rs/polars/issues/15600


-def get_dataframe_schema(df: h_databackends.DATAFRAME_TYPES, node: HamiltonNode) -> pyarrow.Schema:
+def get_dataframe_schema(
+    df: Union[h_databackends.DATAFRAME_TYPES], node: HamiltonNode
+) -> pyarrow.Schema:
     """Get pyarrow schema of a node result and store node metadata on the pyarrow schema."""
     schema = _get_arrow_schema(df)
     metadata = dict(
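
On the h_schema side, the runtime tuple is wrapped back into a Union for the annotation: subscripting typing.Union with a tuple of types yields an ordinary Union type hint, exactly as the old in-module comment noted. A small sketch of that mechanic, with illustrative types:

from typing import Union

TYPES = (int, float)    # runtime tuple, usable with isinstance()
Numeric = Union[TYPES]  # equivalent to Union[int, float] as a type hint

def double(x: Numeric) -> Numeric:
    assert isinstance(x, TYPES)  # the plain tuple handles the runtime check
    return x * 2
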
Empty file added tests/experimental/__init__.py
43 changes: 43 additions & 0 deletions tests/experimental/test_h_databackends.py
@@ -0,0 +1,43 @@
import pandas as pd

from hamilton.experimental import h_databackends


def test_isinstance_dataframe():
    value = pd.DataFrame()
    assert isinstance(value, h_databackends.DATAFRAME_TYPES)


def test_issubclass_dataframe():
    class_ = pd.DataFrame
    assert issubclass(class_, h_databackends.DATAFRAME_TYPES)


def test_not_isinstance_dataframe():
    value = 6
    assert not isinstance(value, h_databackends.DATAFRAME_TYPES)


def test_not_issubclass_dataframe():
    class_ = int
    assert not issubclass(class_, h_databackends.DATAFRAME_TYPES)


def test_isinstance_column():
    value = pd.Series()
    assert isinstance(value, h_databackends.COLUMN_TYPES)


def test_issubclass_column():
    class_ = pd.Series
    assert issubclass(class_, h_databackends.COLUMN_TYPES)


def test_not_isinstance_column():
    value = 6
    assert not isinstance(value, h_databackends.COLUMN_TYPES)


def test_not_issubclass_column():
    class_ = int
    assert not issubclass(class_, h_databackends.COLUMN_TYPES)
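
These tests exercise pandas only, but the same checks extend to any registered backend because each abstract databackend recognizes concrete classes by (module, class name) pairs, as the _backends = [("modin.pandas", "DataFrame")] attribute in the first file shows, rather than by importing the library up front. A hedged sketch of that idea, not Hamilton's exact implementation:

import abc

# Hedged sketch of the databackend idea; not Hamilton's actual code.
class AbstractBackend(abc.ABC):
    _backends = []  # (module, class name) pairs this abstract type matches

    @classmethod
    def __subclasshook__(cls, candidate: type):
        for module, name in cls._backends:
            for klass in candidate.__mro__:
                if (
                    klass.__module__.split(".")[0] == module.split(".")[0]
                    and klass.__name__ == name
                ):
                    return True
        return NotImplemented  # defer to the default issubclass machinery


class AbstractPandasDataFrame(AbstractBackend):
    _backends = [("pandas", "DataFrame")]


import pandas as pd

assert issubclass(pd.DataFrame, AbstractPandasDataFrame)
assert isinstance(pd.DataFrame(), AbstractPandasDataFrame)
assert not issubclass(int, AbstractPandasDataFrame)
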
81 changes: 80 additions & 1 deletion tests/plugins/test_h_schema.py
@@ -1,9 +1,11 @@
 import json
 import pathlib

+import pandas as pd
 import pyarrow
 import pytest

+from hamilton import graph_types
 from hamilton.plugins import h_schema


@@ -140,7 +142,6 @@ def test_schema_edited_schema_metadata(schema1: pyarrow.Schema, metadata1: dict,
     assert schema_diff[h_schema.SCHEMA_METADATA_FIELD].value["key"].diff == h_schema.Diff.UNEQUAL

     human_readable_diff = h_schema.human_readable_diff(schema_diff)
-    print(human_readable_diff)

     assert human_readable_diff == {
         h_schema.SCHEMA_METADATA_FIELD: {"key": {"cur": "value1", "ref": "value2"}}
@@ -227,3 +228,81 @@ def test_save_schema_to_disk(schema1: pyarrow.Schema, tmp_path: pathlib.Path):
     h_schema.save_schema(path=schema_path, schema=schema1)
     loaded_schema = pyarrow.ipc.read_schema(schema_path)
     assert schema1.equals(loaded_schema)


def test_get_dataframe_schema():
    def foo(x: pd.DataFrame) -> pd.DataFrame:
        """doc"""
        return x

    version = graph_types.hash_source_code(foo, strip=True)
    node = graph_types.HamiltonNode(
        name=foo.__name__,
        type=pd.DataFrame,
        documentation=foo.__doc__,
        tags={},
        is_external_input=False,
        originating_functions=(foo,),
        required_dependencies=set(),
        optional_dependencies=set(),
    )
    df = pd.DataFrame({"a": [0, 1], "b": [True, False]})

    expected_schema = pyarrow.schema(
        [
            ("a", "int64"),
            ("b", "bool"),
        ]
    )
    expected_metadata = {
        b"name": foo.__name__.encode(),
        b"documentation": foo.__doc__.encode(),
        b"version": version.encode(),
    }

    schema = h_schema.get_dataframe_schema(df, node)

    assert schema.equals(expected_schema.with_metadata(expected_metadata), check_metadata=True)


def test_schema_validator_after_node_execution(tmp_path):
    def foo(x: pd.DataFrame) -> pd.DataFrame:
        """doc"""
        return x

    version = graph_types.hash_source_code(foo, strip=True)
    node = graph_types.HamiltonNode(
        name=foo.__name__,
        type=pd.DataFrame,
        documentation=foo.__doc__,
        tags={},
        is_external_input=False,
        originating_functions=(foo,),
        required_dependencies=set(),
        optional_dependencies=set(),
    )
    h_graph = graph_types.HamiltonGraph([node])
    df = pd.DataFrame({"a": [0, 1], "b": [True, False]})

    expected_schema = pyarrow.schema(
        [
            ("a", "int64"),
            ("b", "bool"),
        ]
    )
    expected_metadata = {
        b"name": foo.__name__.encode(),
        b"documentation": foo.__doc__.encode(),
        b"version": version.encode(),
    }

    # set the HamiltonGraph on the state of the adapter
    adapter = h_schema.SchemaValidator(schema_dir=tmp_path)
    adapter.h_graph = h_graph

    adapter.run_after_node_execution(node_name=foo.__name__, result=df)

    tracked_schema = adapter.schemas[foo.__name__]
    assert tracked_schema.equals(
        expected_schema.with_metadata(expected_metadata), check_metadata=True
    )
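
For context, here is a hedged sketch of how the SchemaValidator adapter under test is typically attached to a driver so that run_after_node_execution() fires on each node result. The Builder wiring reflects Hamilton's general adapter API rather than anything added in this PR, and my_dataflow is a hypothetical module:

import pandas as pd

from hamilton import driver
from hamilton.plugins import h_schema

import my_dataflow  # hypothetical module defining nodes such as foo(x)

dr = (
    driver.Builder()
    .with_modules(my_dataflow)
    .with_adapters(h_schema.SchemaValidator(schema_dir="./schemas"))
    .build()
)
# Each executed node's pyarrow schema is captured by the adapter.
dr.execute(["foo"], inputs={"x": pd.DataFrame({"a": [0, 1]})})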