Skip to content

Commit

Permalink
Adds tests/fixes up pandas data loaders + savers
Browse files Browse the repository at this point in the history
  • Loading branch information
elijahbenizzy committed Apr 4, 2023
1 parent 070b90b commit dde909f
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 29 deletions.
7 changes: 4 additions & 3 deletions hamilton/function_modifiers/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,10 +368,10 @@ def __getattr__(cls, item: str):
except AttributeError as e:
raise AttributeError(
f"No saver named: {item} available for {cls.__name__}. "
f"Available data savers are: {SAVER_REGISTRY.keys()}. "
f"Available data savers are: {list(SAVER_REGISTRY.keys())}. "
f"If you've gotten to this point, you either (1) spelled the "
f"loader name wrong, (2) are trying to use a saver that does"
f"not exist (yet)"
f"not exist (yet)."
) from e


Expand All @@ -382,6 +382,7 @@ def __init__(
artifact_name_: str = None,
**kwargs: ParametrizedDependency,
):
super(SaveToDecorator, self).__init__()
self.artifact_name = artifact_name_
self.saver_classes = saver_classes_
self.kwargs = kwargs
Expand Down Expand Up @@ -464,7 +465,7 @@ def validate(self, fn: Callable):
pass


class save_to(metaclass=load_from__meta__):
class save_to(metaclass=save_to__meta__):
"""Decorator that outputs data to some external source. You can think
about this as the inverse of load_from.
Expand Down
4 changes: 2 additions & 2 deletions hamilton/io/data_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
class AdapterCommon(abc.ABC):
@classmethod
@abc.abstractmethod
def load_targets(cls) -> Collection[Type]:
def applicable_types(cls) -> Collection[Type]:
"""Returns the types that this data loader can load to.
These will be checked against the desired type to determine
whether this is a suitable loader for that type.
Expand Down Expand Up @@ -37,7 +37,7 @@ def applies_to(cls, type_: Type[Type]) -> bool:
:param type_: Candidate type
:return: True if this data loader can load to the type, False otherwise.
"""
for load_to in cls.load_targets():
for load_to in cls.applicable_types():
if custom_subclass_check(load_to, type_):
return True
return False
Expand Down
10 changes: 5 additions & 5 deletions hamilton/io/default_data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class JSONDataAdapter(DataLoader, DataSaver):
path: str

@classmethod
def load_targets(cls) -> Collection[Type]:
def applicable_types(cls) -> Collection[Type]:
return [dict]

def load_data(self, type_: Type) -> Tuple[dict, Dict[str, Any]]:
Expand All @@ -40,7 +40,7 @@ def load_data(self, type_: Type) -> Tuple[str, Dict[str, Any]]:
return f.read(), get_file_loading_metadata(self.path)

@classmethod
def load_targets(cls) -> Collection[Type]:
def applicable_types(cls) -> Collection[Type]:
return [str]

@classmethod
Expand All @@ -58,7 +58,7 @@ class PickleLoader(DataLoader):
path: str

@classmethod
def load_targets(cls) -> Collection[Type]:
def applicable_types(cls) -> Collection[Type]:
return [object]

@classmethod
Expand Down Expand Up @@ -87,7 +87,7 @@ def name(cls) -> str:
return "environment"

@classmethod
def load_targets(cls) -> Collection[Type]:
def applicable_types(cls) -> Collection[Type]:
return [dict]


Expand All @@ -96,7 +96,7 @@ class LiteralValueDataLoader(DataLoader):
value: Any

@classmethod
def load_targets(cls) -> Collection[Type]:
def applicable_types(cls) -> Collection[Type]:
return [Any]

def load_data(self, type_: Type) -> Tuple[dict, Dict[str, Any]]:
Expand Down
44 changes: 36 additions & 8 deletions hamilton/plugins/pandas_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Collection, Dict, Tuple, Type

from hamilton.io import utils
from hamilton.io.data_adapters import DataLoader
from hamilton.io.data_adapters import DataLoader, DataSaver

try:
import pandas as pd
Expand Down Expand Up @@ -35,27 +35,47 @@ def register_types():
register_types()


class DataFrameDataLoader(DataLoader, abc.ABC):
"""Base class for data loaders that load pandas dataframes."""
class DataFrameDataLoader(DataLoader, DataSaver, abc.ABC):
"""Base class for data loaders that saves/loads pandas dataframes.
Note that these are currently grouped together, but this could change!
We can change this as these are not part of the publicly exposed APIs.
Rather, the fixed component is the keys (E.G. csv, feather, etc...) , which,
when combined with types, correspond to a group of specific parameter. As such,
the backwards-compatible invariance enables us to change the implementation
(which classes), and so long as the set of parameters/load targets are compatible,
we are good to go."""

@classmethod
def load_targets(cls) -> Collection[Type]:
def applicable_types(cls) -> Collection[Type]:
return [DATAFRAME_TYPE]

@abc.abstractmethod
def load_data(self, type_: Type[DATAFRAME_TYPE]) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]:
    """Loads a pandas dataframe from this adapter's source.

    Abstract -- concrete subclasses (CSV/feather/parquet, etc.) implement the
    actual read.

    :param type_: the requested target type (a pandas dataframe type)
    :return: tuple of (loaded dataframe, metadata dict about the load)
    """
    pass

@abc.abstractmethod
def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
    """Saves a pandas dataframe to this adapter's destination.

    Abstract -- concrete subclasses (CSV/feather/parquet, etc.) implement the
    actual write.

    :param data: dataframe to save
    :return: metadata dict describing the saved artifact
    """
    pass


@dataclasses.dataclass
class CSVDataLoader(DataFrameDataLoader):
class CSVDataAdapter(DataFrameDataLoader):
"""Data loader for CSV files. Note that this currently does not support the wide array of
data loading functionality that pandas does. We will be adding this in over time, but for now
you can subclass this or open up an issue if this doesn't have what you want."""
you can subclass this or open up an issue if this doesn't have what you want.
Note that, when saving, this does not currently save the index.
We'll likely want to enable this in the future as an optional subclass,
in which case we'll separate it out.
"""

path: str

def load_data(self, type_: Type) -> Tuple[pd.DataFrame, Dict[str, Any]]:
def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
    """Writes the dataframe to ``self.path`` as CSV.

    Note: saved with ``index=False``, so the dataframe index is NOT persisted
    (a round-trip load will produce a fresh default index).

    :param data: dataframe to save
    :return: metadata from utils.get_file_loading_metadata -- presumably
        file-level info (path, size, etc.); TODO confirm against utils
    """
    data.to_csv(self.path, index=False)
    return utils.get_file_loading_metadata(self.path)

def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]:
    """Reads a CSV file from ``self.path`` into a dataframe.

    :param type_: requested type; unused here -- type applicability is
        decided by the base class's applicable_types check
    :return: tuple of (dataframe read via pd.read_csv, file metadata)
    """
    df = pd.read_csv(self.path)
    metadata = utils.get_file_loading_metadata(self.path)
    return df, metadata
Expand All @@ -73,6 +93,10 @@ class FeatherDataLoader(DataFrameDataLoader):

path: str

def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
    """Writes the dataframe to ``self.path`` in feather format.

    NOTE(review): pd.DataFrame.to_feather requires a default (RangE-like)
    index -- a non-default index will raise; confirm callers reset the
    index if needed.

    :param data: dataframe to save
    :return: metadata from utils.get_file_loading_metadata -- presumably
        file-level info (path, size, etc.); TODO confirm against utils
    """
    data.to_feather(self.path)
    return utils.get_file_loading_metadata(self.path)

def load_data(self, type_: Type[DATAFRAME_TYPE]) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]:
df = pd.read_feather(self.path)
metadata = utils.get_file_loading_metadata(self.path)
Expand All @@ -96,14 +120,18 @@ def load_data(self, type_: Type[DATAFRAME_TYPE]) -> Tuple[DATAFRAME_TYPE, Dict[s
metadata = utils.get_file_loading_metadata(self.path)
return df, metadata

def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
    """Writes the dataframe to ``self.path`` in parquet format.

    :param data: dataframe to save
    :return: metadata from utils.get_file_loading_metadata -- presumably
        file-level info (path, size, etc.); TODO confirm against utils
    """
    data.to_parquet(self.path)
    return utils.get_file_loading_metadata(self.path)

@classmethod
def name(cls) -> str:
    """Registry key for this adapter: combined with the applicable types,
    it selects this loader/saver (e.g. for ``load_from.parquet``)."""
    return "parquet"


def register_data_loaders():
"""Function to register the data loaders for this extension."""
for loader in [CSVDataLoader, FeatherDataLoader, ParquetDataLoader]:
for loader in [CSVDataAdapter, FeatherDataLoader, ParquetDataLoader]:
registry.register_adapter(loader)


Expand Down
33 changes: 23 additions & 10 deletions tests/function_modifiers/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from hamilton import ad_hoc_utils, base, driver, graph, node
from hamilton.function_modifiers import base as fm_base
from hamilton.function_modifiers import source, value
from hamilton.function_modifiers import save_to, source, value
from hamilton.function_modifiers.adapters import (
LoadFromDecorator,
SaveToDecorator,
Expand Down Expand Up @@ -47,7 +47,7 @@ class MockDataLoader(DataLoader):
default_param_3: str = "6"

@classmethod
def load_targets(cls) -> Collection[Type]:
def applicable_types(cls) -> Collection[Type]:
return [int]

def load_data(self, type_: Type[int]) -> Tuple[int, Dict[str, Any]]:
Expand Down Expand Up @@ -158,7 +158,7 @@ def name(cls) -> str:
return "dummy"

@classmethod
def load_targets(cls) -> Collection[Type]:
def applicable_types(cls) -> Collection[Type]:
return [str]


Expand All @@ -172,14 +172,14 @@ def name(cls) -> str:
return "dummy"

@classmethod
def load_targets(cls) -> Collection[Type]:
def applicable_types(cls) -> Collection[Type]:
return [int]


@dataclasses.dataclass
class IntDataLoader2(DataLoader):
@classmethod
def load_targets(cls) -> Collection[Type]:
def applicable_types(cls) -> Collection[Type]:
return [int]

def load_data(self, type_: Type) -> Tuple[int, Dict[str, Any]]:
Expand Down Expand Up @@ -344,8 +344,12 @@ def test_loader_fails_for_missing_attribute():
load_from.not_a_loader(param=value("foo"))


def test_pandas_extensions():
@load_from.csv(path=value("tests/resources/data/test_load_from_data.csv"))
def test_pandas_extensions_end_to_end(tmp_path_factory):
output_path = tmp_path_factory.mktemp("test_pandas_extensions_end_to_end") / "output.csv"
input_path = "tests/resources/data/test_load_from_data.csv"

@save_to.csv(path=source("output_path"), artifact_name_="save_df")
@load_from.csv(path=source("input_path"))
def df(data: pd.DataFrame) -> pd.DataFrame:
return data

Expand All @@ -355,13 +359,22 @@ def df(data: pd.DataFrame) -> pd.DataFrame:
ad_hoc_utils.create_temporary_module(df),
adapter=base.SimplePythonGraphAdapter(base.DictResult()),
)
# run once to check that loading is correct
result = dr.execute(
["df"],
inputs={"test_data": "tests/resources/data/test_load_from_data.csv"},
["df", "save_df"],
inputs={"input_path": input_path, "output_path": output_path},
)
assert result["df"].shape == (3, 5)
assert result["df"].loc[0, "firstName"] == "John"

#
result_just_read = dr.execute(
["df"],
inputs={"input_path": output_path},
)
# This is just reading the same file we wrote out, so it should be the same
pd.testing.assert_frame_equal(result["df"], result_just_read["df"])


@dataclasses.dataclass
class MarkingSaver(DataSaver):
Expand All @@ -374,7 +387,7 @@ def save_data(self, data: int) -> Dict[str, Any]:
return {}

@classmethod
def load_targets(cls) -> Collection[Type]:
def applicable_types(cls) -> Collection[Type]:
return [int]

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion tests/materialization/test_data_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class MockDataLoader(DataLoader):
default_param: int = 1

@classmethod
def load_targets(cls) -> Collection[Type]:
def applicable_types(cls) -> Collection[Type]:
return [bool]

def load_data(self, type_: Type) -> Tuple[int, Dict[str, Any]]:
Expand Down

0 comments on commit dde909f

Please sign in to comment.