Fixes #1240 (#1241)
* Fixes #1240

The cache store assumed that every persister took a `path` argument. That is
not the case, because the savers / loaders wrap external APIs and we decided not
to create our own abstraction layer around them, but to mirror them instead.

E.g. polars takes `file`, but pandas takes `path`.

This means future savers / loaders may require further changes here (a standalone
sketch of the resulting dispatch follows this message).

* Adds tests

To catch the case where the persister takes `file`, and the case where it takes
neither `path` nor `file`.
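
For illustration, here is a minimal, standalone sketch of the dispatch this change introduces. It is not Hamilton's API: `pick_ctor_kwarg`, `FakeFileSaver`, and `FakePathSaver` are hypothetical names used only to show how the constructor signature decides whether `file=` or `path=` is passed.

import inspect


def pick_ctor_kwarg(cls) -> str:
    """Return the constructor keyword a persister expects for its output location.

    Mirrors the dispatch added in FileResultStore.set: prefer `file`, fall back
    to `path`, and raise if neither is accepted.
    """
    args = inspect.getfullargspec(cls.__init__).args
    if "file" in args:
        return "file"
    if "path" in args:
        return "path"
    raise ValueError(f"{cls.__name__} must take either `file` or `path`.")


class FakeFileSaver:  # polars-style: writes to `file`
    def __init__(self, file):
        self.file = file


class FakePathSaver:  # pandas-style: writes to `path`
    def __init__(self, path):
        self.path = path


for cls in (FakeFileSaver, FakePathSaver):
    kwarg = pick_ctor_kwarg(cls)  # "file" for the first, "path" for the second
    saver = cls(**{kwarg: "/tmp/result.parquet"})
    print(cls.__name__, "->", kwarg)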
skrawcz authored Nov 26, 2024
1 parent 47a5146 commit 911c3ca
Showing 2 changed files with 126 additions and 2 deletions.
21 changes: 19 additions & 2 deletions hamilton/caching/stores/file.py
@@ -1,3 +1,4 @@
+import inspect
 import shutil
 from pathlib import Path
 from typing import Any, Optional
@@ -68,8 +69,24 @@ def set(
         if saver_cls is not None:
             # materialized_path
             materialized_path = self._materialized_path(data_version, saver_cls)
-            saver = saver_cls(path=str(materialized_path.absolute()))
-            loader = loader_cls(path=str(materialized_path.absolute()))
+            saver_argspec = inspect.getfullargspec(saver_cls.__init__)
+            loader_argspec = inspect.getfullargspec(loader_cls.__init__)
+            if "file" in saver_argspec.args:
+                saver = saver_cls(file=str(materialized_path.absolute()))
+            elif "path" in saver_argspec.args:
+                saver = saver_cls(path=str(materialized_path.absolute()))
+            else:
+                raise ValueError(
+                    f"Saver [{saver_cls.name()}] must have either `file` or `path` as an argument."
+                )
+            if "file" in loader_argspec.args:
+                loader = loader_cls(file=str(materialized_path.absolute()))
+            elif "path" in loader_argspec.args:
+                loader = loader_cls(path=str(materialized_path.absolute()))
+            else:
+                raise ValueError(
+                    f"Loader [{loader_cls.name()}] must have either `file` or `path` as an argument."
+                )
         else:
             saver = None
             loader = None
107 changes: 107 additions & 0 deletions tests/caching/test_result_store.py
@@ -1,11 +1,13 @@
 import pathlib
 import pickle
+from typing import Any, Collection, Dict, Tuple, Type

 import pytest

 from hamilton.caching import fingerprinting
 from hamilton.caching.stores.base import search_data_adapter_registry
 from hamilton.caching.stores.file import FileResultStore
+from hamilton.io.data_adapters import DataLoader, DataSaver


 @pytest.fixture
@@ -114,3 +116,108 @@ def test_save_and_load_materializer(format, value, result_store):

     assert materialized_path.exists()
     assert fingerprinting.hash_value(value) == fingerprinting.hash_value(retrieved_value)
+
+
+class FakeParquetSaver(DataSaver):
+    def __init__(self, file):
+        self.file = file
+
+    def save_data(self, data: Any) -> Dict[str, Any]:
+        with open(self.file, "w") as f:
+            f.write(str(data))
+        return {"meta": "data"}
+
+    @classmethod
+    def applicable_types(cls) -> Collection[Type]:
+        pass
+
+    @classmethod
+    def name(cls) -> str:
+        return "fake_parquet"
+
+
+class FakeParquetLoader(DataLoader):
+    def __init__(self, file):
+        self.file = file
+
+    def load_data(self, type_: Type[Type]) -> Tuple[Type, Dict[str, Any]]:
+        with open(self.file, "r") as f:
+            data = eval(f.read())
+        return data, {"meta": data}
+
+    @classmethod
+    def applicable_types(cls) -> Collection[Type]:
+        pass
+
+    @classmethod
+    def name(cls) -> str:
+        return "fake_parquet"
+
+
+def test_save_and_load_file_in_init(result_store):
+    value = {"a": 1}
+    saver_cls, loader_cls = (FakeParquetSaver, FakeParquetLoader)
+    data_version = "foo"
+    materialized_path = result_store._materialized_path(data_version, saver_cls)
+
+    result_store.set(
+        data_version=data_version, result=value, saver_cls=saver_cls, loader_cls=loader_cls
+    )
+    retrieved_value = result_store.get(data_version)
+
+    assert materialized_path.exists()
+    assert fingerprinting.hash_value(value) == fingerprinting.hash_value(retrieved_value)
+
+
+class BadSaver(DataSaver):
+    def __init__(self, file123):
+        self.file = file123
+
+    def save_data(self, data: Any) -> Dict[str, Any]:
+        with open(self.file, "w") as f:
+            f.write(str(data))
+        return {"meta": "data"}
+
+    @classmethod
+    def applicable_types(cls) -> Collection[Type]:
+        pass
+
+    @classmethod
+    def name(cls) -> str:
+        return "fake_parquet"
+
+
+class BadLoader(DataLoader):
+    def __init__(self, file123):
+        self.file = file123
+
+    def load_data(self, type_: Type[Type]) -> Tuple[Type, Dict[str, Any]]:
+        with open(self.file, "r") as f:
+            data = eval(f.read())
+        return data, {"meta": data}
+
+    @classmethod
+    def applicable_types(cls) -> Collection[Type]:
+        pass
+
+    @classmethod
+    def name(cls) -> str:
+        return "fake_parquet"
+
+
+def test_save_and_load_not_path_not_file_init_error(result_store):
+    value = {"a": 1}
+    saver_cls, loader_cls = (BadSaver, BadLoader)
+    data_version = "foo"
+    with pytest.raises(ValueError):
+        result_store.set(
+            data_version=data_version, result=value, saver_cls=saver_cls, loader_cls=loader_cls
+        )
+    with pytest.raises(ValueError):
+        result_store.set(  # make something store it in the result store
+            data_version=data_version,
+            result=value,
+            saver_cls=FakeParquetSaver,
+            loader_cls=loader_cls,
+        )
+    result_store.get(data_version)
