From 2e4687210be6b2f4e1a057de290582a2dcb2d7ce Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Sat, 23 Nov 2024 17:01:12 -0800 Subject: [PATCH 1/2] Fixes #1240 The cache store assumed that every persister took a `path` argument. That is not the case because the savers / loaders wrap external APIs and we decided to not try to create our own abstraction layer around them, and instead mirror them. E.g. polars takes `file`, but pandas takes `path`. --- hamilton/caching/stores/file.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/hamilton/caching/stores/file.py b/hamilton/caching/stores/file.py index 34144d9c1..9721a7232 100644 --- a/hamilton/caching/stores/file.py +++ b/hamilton/caching/stores/file.py @@ -1,3 +1,4 @@ +import inspect import shutil from pathlib import Path from typing import Any, Optional @@ -68,8 +69,24 @@ def set( if saver_cls is not None: # materialized_path materialized_path = self._materialized_path(data_version, saver_cls) - saver = saver_cls(path=str(materialized_path.absolute())) - loader = loader_cls(path=str(materialized_path.absolute())) + saver_argspec = inspect.getfullargspec(saver_cls.__init__) + loader_argspec = inspect.getfullargspec(loader_cls.__init__) + if "file" in saver_argspec.args: + saver = saver_cls(file=str(materialized_path.absolute())) + elif "path" in saver_argspec.args: + saver = saver_cls(path=str(materialized_path.absolute())) + else: + raise ValueError( + f"Saver [{saver_cls.name()}] must have either `file` or `path` as an argument." + ) + if "file" in loader_argspec.args: + loader = loader_cls(file=str(materialized_path.absolute())) + elif "path" in loader_argspec.args: + loader = loader_cls(path=str(materialized_path.absolute())) + else: + raise ValueError( + f"Loader [{loader_cls.name()}] must have either `file` or `path` as an argument." + ) else: saver = None loader = None From e12b943b4a3c1f3d5db5d0a957adeb603f61a937 Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Sat, 23 Nov 2024 20:43:50 -0800 Subject: [PATCH 2/2] Adds tests To catch case with `file` and without `path` or `file`. --- tests/caching/test_result_store.py | 107 +++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/tests/caching/test_result_store.py b/tests/caching/test_result_store.py index 9afbd455a..ab94840ea 100644 --- a/tests/caching/test_result_store.py +++ b/tests/caching/test_result_store.py @@ -1,11 +1,13 @@ import pathlib import pickle +from typing import Any, Collection, Dict, Tuple, Type import pytest from hamilton.caching import fingerprinting from hamilton.caching.stores.base import search_data_adapter_registry from hamilton.caching.stores.file import FileResultStore +from hamilton.io.data_adapters import DataLoader, DataSaver @pytest.fixture @@ -114,3 +116,108 @@ def test_save_and_load_materializer(format, value, result_store): assert materialized_path.exists() assert fingerprinting.hash_value(value) == fingerprinting.hash_value(retrieved_value) + + +class FakeParquetSaver(DataSaver): + def __init__(self, file): + self.file = file + + def save_data(self, data: Any) -> Dict[str, Any]: + with open(self.file, "w") as f: + f.write(str(data)) + return {"meta": "data"} + + @classmethod + def applicable_types(cls) -> Collection[Type]: + pass + + @classmethod + def name(cls) -> str: + return "fake_parquet" + + +class FakeParquetLoader(DataLoader): + def __init__(self, file): + self.file = file + + def load_data(self, type_: Type[Type]) -> Tuple[Type, Dict[str, Any]]: + with open(self.file, "r") as f: + data = eval(f.read()) + return data, {"meta": data} + + @classmethod + def applicable_types(cls) -> Collection[Type]: + pass + + @classmethod + def name(cls) -> str: + return "fake_parquet" + + +def test_save_and_load_file_in_init(result_store): + value = {"a": 1} + saver_cls, loader_cls = (FakeParquetSaver, FakeParquetLoader) + data_version = "foo" + materialized_path = result_store._materialized_path(data_version, saver_cls) + + result_store.set( + data_version=data_version, result=value, saver_cls=saver_cls, loader_cls=loader_cls + ) + retrieved_value = result_store.get(data_version) + + assert materialized_path.exists() + assert fingerprinting.hash_value(value) == fingerprinting.hash_value(retrieved_value) + + +class BadSaver(DataSaver): + def __init__(self, file123): + self.file = file123 + + def save_data(self, data: Any) -> Dict[str, Any]: + with open(self.file, "w") as f: + f.write(str(data)) + return {"meta": "data"} + + @classmethod + def applicable_types(cls) -> Collection[Type]: + pass + + @classmethod + def name(cls) -> str: + return "fake_parquet" + + +class BadLoader(DataLoader): + def __init__(self, file123): + self.file = file123 + + def load_data(self, type_: Type[Type]) -> Tuple[Type, Dict[str, Any]]: + with open(self.file, "r") as f: + data = eval(f.read()) + return data, {"meta": data} + + @classmethod + def applicable_types(cls) -> Collection[Type]: + pass + + @classmethod + def name(cls) -> str: + return "fake_parquet" + + +def test_save_and_load_not_path_not_file_init_error(result_store): + value = {"a": 1} + saver_cls, loader_cls = (BadSaver, BadLoader) + data_version = "foo" + with pytest.raises(ValueError): + result_store.set( + data_version=data_version, result=value, saver_cls=saver_cls, loader_cls=loader_cls + ) + with pytest.raises(ValueError): + result_store.set( # make something store it in the result store + data_version=data_version, + result=value, + saver_cls=FakeParquetSaver, + loader_cls=loader_cls, + ) + result_store.get(data_version)