Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #1240 #1241

Merged
merged 2 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions hamilton/caching/stores/file.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import shutil
from pathlib import Path
from typing import Any, Optional
Expand Down Expand Up @@ -68,8 +69,24 @@ def set(
if saver_cls is not None:
# materialized_path
materialized_path = self._materialized_path(data_version, saver_cls)
saver = saver_cls(path=str(materialized_path.absolute()))
loader = loader_cls(path=str(materialized_path.absolute()))
saver_argspec = inspect.getfullargspec(saver_cls.__init__)
skrawcz marked this conversation as resolved.
Show resolved Hide resolved
loader_argspec = inspect.getfullargspec(loader_cls.__init__)
if "file" in saver_argspec.args:
saver = saver_cls(file=str(materialized_path.absolute()))
elif "path" in saver_argspec.args:
saver = saver_cls(path=str(materialized_path.absolute()))
else:
raise ValueError(
f"Saver [{saver_cls.name()}] must have either `file` or `path` as an argument."
)
if "file" in loader_argspec.args:
loader = loader_cls(file=str(materialized_path.absolute()))
elif "path" in loader_argspec.args:
loader = loader_cls(path=str(materialized_path.absolute()))
else:
raise ValueError(
f"Loader [{loader_cls.name()}] must have either `file` or `path` as an argument."
)
else:
saver = None
loader = None
Expand Down
107 changes: 107 additions & 0 deletions tests/caching/test_result_store.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import pathlib
import pickle
from typing import Any, Collection, Dict, Tuple, Type

import pytest

from hamilton.caching import fingerprinting
from hamilton.caching.stores.base import search_data_adapter_registry
from hamilton.caching.stores.file import FileResultStore
from hamilton.io.data_adapters import DataLoader, DataSaver


@pytest.fixture
Expand Down Expand Up @@ -114,3 +116,108 @@ def test_save_and_load_materializer(format, value, result_store):

assert materialized_path.exists()
assert fingerprinting.hash_value(value) == fingerprinting.hash_value(retrieved_value)


class FakeParquetSaver(DataSaver):
    """Minimal test saver whose constructor accepts ``file`` rather than ``path``."""

    def __init__(self, file):
        # Destination handed in under the ``file`` keyword.
        self.file = file

    def save_data(self, data: Any) -> Dict[str, Any]:
        """Write ``str(data)`` to ``self.file`` and return a fixed metadata dict."""
        with open(self.file, "w") as sink:
            serialized = str(data)
            sink.write(serialized)
        metadata = {"meta": "data"}
        return metadata

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        # Left unimplemented for this fake: returns None.
        return None

    @classmethod
    def name(cls) -> str:
        """Return the adapter name used for this fake."""
        adapter_name = "fake_parquet"
        return adapter_name


class FakeParquetLoader(DataLoader):
    """Minimal test loader whose constructor accepts ``file`` rather than ``path``.

    It reads back text of the kind ``FakeParquetSaver.save_data`` writes,
    i.e. the ``str()`` of a Python value.
    """

    def __init__(self, file):
        # Source location handed in under the ``file`` keyword.
        self.file = file

    def load_data(self, type_: Type[Type]) -> Tuple[Type, Dict[str, Any]]:
        """Parse the file contents as a Python literal.

        :param type_: Requested type; not used by this fake.
        :return: Tuple of the parsed value and a metadata dict ``{"meta": <value>}``.
        :raises ValueError: If the file content is not a valid Python literal.
        :raises SyntaxError: If the file content is not parseable Python.
        """
        # Local import keeps this class self-contained without touching file-level imports.
        import ast

        with open(self.file, "r") as f:
            # literal_eval only accepts literals (dicts, lists, numbers, strings, ...),
            # so unlike eval() it cannot execute arbitrary code found in the file.
            data = ast.literal_eval(f.read())
        return data, {"meta": data}

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        # Left unimplemented for this fake: returns None.
        pass

    @classmethod
    def name(cls) -> str:
        """Return the adapter name used for this fake."""
        return "fake_parquet"


def test_save_and_load_file_in_init(result_store):
    """Round-trip a value through adapters whose ``__init__`` takes ``file``."""
    data_version = "foo"
    value = {"a": 1}
    expected_path = result_store._materialized_path(data_version, FakeParquetSaver)

    result_store.set(
        data_version=data_version,
        result=value,
        saver_cls=FakeParquetSaver,
        loader_cls=FakeParquetLoader,
    )
    retrieved_value = result_store.get(data_version)

    # The saver must have written to the materialized location ...
    assert expected_path.exists()
    # ... and the loader must give back an equivalent value.
    assert fingerprinting.hash_value(value) == fingerprinting.hash_value(retrieved_value)


class BadSaver(DataSaver):
    """Test saver whose constructor takes ``file123``, i.e. neither ``file`` nor ``path``."""

    def __init__(self, file123):
        # Stored under ``file`` even though the keyword is ``file123``.
        self.file = file123

    def save_data(self, data: Any) -> Dict[str, Any]:
        """Write ``str(data)`` to ``self.file`` and return a fixed metadata dict."""
        with open(self.file, "w") as out:
            text = str(data)
            out.write(text)
        result = {"meta": "data"}
        return result

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        # Left unimplemented for this fake: returns None.
        return None

    @classmethod
    def name(cls) -> str:
        """Return the adapter name used for this fake."""
        label = "fake_parquet"
        return label


class BadLoader(DataLoader):
    """Test loader whose constructor takes ``file123``, i.e. neither ``file`` nor ``path``."""

    def __init__(self, file123):
        # Stored under ``file`` even though the keyword is ``file123``.
        self.file = file123

    def load_data(self, type_: Type[Type]) -> Tuple[Type, Dict[str, Any]]:
        """Parse the file contents as a Python literal.

        :param type_: Requested type; not used by this fake.
        :return: Tuple of the parsed value and a metadata dict ``{"meta": <value>}``.
        :raises ValueError: If the file content is not a valid Python literal.
        :raises SyntaxError: If the file content is not parseable Python.
        """
        # Local import keeps this class self-contained without touching file-level imports.
        import ast

        with open(self.file, "r") as f:
            # literal_eval instead of eval(): parses literals only, never executes code.
            data = ast.literal_eval(f.read())
        return data, {"meta": data}

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        # Left unimplemented for this fake: returns None.
        pass

    @classmethod
    def name(cls) -> str:
        """Return the adapter name used for this fake."""
        return "fake_parquet"


def test_save_and_load_not_path_not_file_init_error(result_store):
    """``set`` must raise ``ValueError`` for adapters taking neither ``file`` nor ``path``."""
    value = {"a": 1}
    saver_cls, loader_cls = (BadSaver, BadLoader)
    data_version = "foo"

    # Saver and loader are both unsupported.
    with pytest.raises(ValueError):
        result_store.set(
            data_version=data_version, result=value, saver_cls=saver_cls, loader_cls=loader_cls
        )

    # Supported saver, unsupported loader: the loader alone must cause the rejection.
    # Only the call expected to raise lives inside the ``raises`` block, so the
    # assertion cannot be satisfied by some other statement.
    with pytest.raises(ValueError):
        result_store.set(
            data_version=data_version,
            result=value,
            saver_cls=FakeParquetSaver,
            loader_cls=loader_cls,
        )