From e12b943b4a3c1f3d5db5d0a957adeb603f61a937 Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Sat, 23 Nov 2024 20:43:50 -0800 Subject: [PATCH] Adds tests To catch case with `file` and without `path` or `file`. --- tests/caching/test_result_store.py | 107 +++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/tests/caching/test_result_store.py b/tests/caching/test_result_store.py index 9afbd455a..ab94840ea 100644 --- a/tests/caching/test_result_store.py +++ b/tests/caching/test_result_store.py @@ -1,11 +1,13 @@ import pathlib import pickle +from typing import Any, Collection, Dict, Tuple, Type import pytest from hamilton.caching import fingerprinting from hamilton.caching.stores.base import search_data_adapter_registry from hamilton.caching.stores.file import FileResultStore +from hamilton.io.data_adapters import DataLoader, DataSaver @pytest.fixture @@ -114,3 +116,108 @@ def test_save_and_load_materializer(format, value, result_store): assert materialized_path.exists() assert fingerprinting.hash_value(value) == fingerprinting.hash_value(retrieved_value) + + +class FakeParquetSaver(DataSaver): + def __init__(self, file): + self.file = file + + def save_data(self, data: Any) -> Dict[str, Any]: + with open(self.file, "w") as f: + f.write(str(data)) + return {"meta": "data"} + + @classmethod + def applicable_types(cls) -> Collection[Type]: + pass + + @classmethod + def name(cls) -> str: + return "fake_parquet" + + +class FakeParquetLoader(DataLoader): + def __init__(self, file): + self.file = file + + def load_data(self, type_: Type[Type]) -> Tuple[Type, Dict[str, Any]]: + with open(self.file, "r") as f: + data = eval(f.read()) + return data, {"meta": data} + + @classmethod + def applicable_types(cls) -> Collection[Type]: + pass + + @classmethod + def name(cls) -> str: + return "fake_parquet" + + +def test_save_and_load_file_in_init(result_store): + value = {"a": 1} + saver_cls, loader_cls = (FakeParquetSaver, FakeParquetLoader) + data_version = "foo" + materialized_path = result_store._materialized_path(data_version, saver_cls) + + result_store.set( + data_version=data_version, result=value, saver_cls=saver_cls, loader_cls=loader_cls + ) + retrieved_value = result_store.get(data_version) + + assert materialized_path.exists() + assert fingerprinting.hash_value(value) == fingerprinting.hash_value(retrieved_value) + + +class BadSaver(DataSaver): + def __init__(self, file123): + self.file = file123 + + def save_data(self, data: Any) -> Dict[str, Any]: + with open(self.file, "w") as f: + f.write(str(data)) + return {"meta": "data"} + + @classmethod + def applicable_types(cls) -> Collection[Type]: + pass + + @classmethod + def name(cls) -> str: + return "fake_parquet" + + +class BadLoader(DataLoader): + def __init__(self, file123): + self.file = file123 + + def load_data(self, type_: Type[Type]) -> Tuple[Type, Dict[str, Any]]: + with open(self.file, "r") as f: + data = eval(f.read()) + return data, {"meta": data} + + @classmethod + def applicable_types(cls) -> Collection[Type]: + pass + + @classmethod + def name(cls) -> str: + return "fake_parquet" + + +def test_save_and_load_not_path_not_file_init_error(result_store): + value = {"a": 1} + saver_cls, loader_cls = (BadSaver, BadLoader) + data_version = "foo" + with pytest.raises(ValueError): + result_store.set( + data_version=data_version, result=value, saver_cls=saver_cls, loader_cls=loader_cls + ) + with pytest.raises(ValueError): + result_store.set( # make something store it in the result store + data_version=data_version, + result=value, + saver_cls=FakeParquetSaver, + loader_cls=loader_cls, + ) + result_store.get(data_version)