Skip to content

Commit

Permalink
Adds tests
Browse files Browse the repository at this point in the history
To catch case with `file` and without `path` or `file`.
  • Loading branch information
skrawcz committed Nov 24, 2024
1 parent 2e46872 commit e12b943
Showing 1 changed file with 107 additions and 0 deletions.
107 changes: 107 additions & 0 deletions tests/caching/test_result_store.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import pathlib
import pickle
from typing import Any, Collection, Dict, Tuple, Type

import pytest

from hamilton.caching import fingerprinting
from hamilton.caching.stores.base import search_data_adapter_registry
from hamilton.caching.stores.file import FileResultStore
from hamilton.io.data_adapters import DataLoader, DataSaver


@pytest.fixture
Expand Down Expand Up @@ -114,3 +116,108 @@ def test_save_and_load_materializer(format, value, result_store):

assert materialized_path.exists()
assert fingerprinting.hash_value(value) == fingerprinting.hash_value(retrieved_value)


class FakeParquetSaver(DataSaver):
def __init__(self, file):
self.file = file

def save_data(self, data: Any) -> Dict[str, Any]:
with open(self.file, "w") as f:
f.write(str(data))
return {"meta": "data"}

@classmethod
def applicable_types(cls) -> Collection[Type]:
pass

@classmethod
def name(cls) -> str:
return "fake_parquet"


class FakeParquetLoader(DataLoader):
def __init__(self, file):
self.file = file

def load_data(self, type_: Type[Type]) -> Tuple[Type, Dict[str, Any]]:
with open(self.file, "r") as f:
data = eval(f.read())
return data, {"meta": data}

@classmethod
def applicable_types(cls) -> Collection[Type]:
pass

@classmethod
def name(cls) -> str:
return "fake_parquet"


def test_save_and_load_file_in_init(result_store):
value = {"a": 1}
saver_cls, loader_cls = (FakeParquetSaver, FakeParquetLoader)
data_version = "foo"
materialized_path = result_store._materialized_path(data_version, saver_cls)

result_store.set(
data_version=data_version, result=value, saver_cls=saver_cls, loader_cls=loader_cls
)
retrieved_value = result_store.get(data_version)

assert materialized_path.exists()
assert fingerprinting.hash_value(value) == fingerprinting.hash_value(retrieved_value)


class BadSaver(DataSaver):
def __init__(self, file123):
self.file = file123

def save_data(self, data: Any) -> Dict[str, Any]:
with open(self.file, "w") as f:
f.write(str(data))
return {"meta": "data"}

@classmethod
def applicable_types(cls) -> Collection[Type]:
pass

@classmethod
def name(cls) -> str:
return "fake_parquet"


class BadLoader(DataLoader):
def __init__(self, file123):
self.file = file123

def load_data(self, type_: Type[Type]) -> Tuple[Type, Dict[str, Any]]:
with open(self.file, "r") as f:
data = eval(f.read())
return data, {"meta": data}

@classmethod
def applicable_types(cls) -> Collection[Type]:
pass

@classmethod
def name(cls) -> str:
return "fake_parquet"


def test_save_and_load_not_path_not_file_init_error(result_store):
value = {"a": 1}
saver_cls, loader_cls = (BadSaver, BadLoader)
data_version = "foo"
with pytest.raises(ValueError):
result_store.set(
data_version=data_version, result=value, saver_cls=saver_cls, loader_cls=loader_cls
)
with pytest.raises(ValueError):
result_store.set( # make something store it in the result store
data_version=data_version,
result=value,
saver_cls=FakeParquetSaver,
loader_cls=loader_cls,
)
result_store.get(data_version)

0 comments on commit e12b943

Please sign in to comment.