@dataclasses.dataclass
class RawFileDataSaverBytes(DataSaver):
    """Saver that writes raw binary content to a file.

    Accepts either a ``bytes`` object or an ``io.BytesIO`` buffer and writes
    the content to ``path`` in binary mode.
    """

    # Destination file, given either as a string or as a ``pathlib.Path``.
    path: Union[pathlib.Path, str]

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        """Return the types this saver accepts: ``bytes`` and ``io.BytesIO``."""
        return [bytes, io.BytesIO]

    @classmethod
    def name(cls) -> str:
        """Return the name of this saver, ``"file"``."""
        return "file"

    def save_data(self, data: Union[bytes, io.BytesIO]) -> Dict[str, Any]:
        """Write ``data`` to ``self.path`` in binary mode.

        :param data: Raw bytes, or a ``BytesIO`` buffer. For a buffer the
            entire content is written (``getvalue()`` ignores the current
            stream position).
        :return: The result of ``get_file_metadata`` for the written path.
        """
        payload = data.getvalue() if isinstance(data, io.BytesIO) else data
        with open(self.path, "wb") as sink:
            sink.write(payload)
        # get_file_metadata is handed a plain string, so normalise Path objects.
        return get_file_metadata(str(self.path))
@pytest.mark.parametrize(
    "data",
    [
        b"test",
        io.BytesIO(b"test"),
    ],
)
def test_raw_file_adapter(data, tmp_path: pathlib.Path) -> None:
    """Check that RawFileDataSaverBytes writes bytes and BytesIO payloads verbatim.

    :param data: Parametrized payload, either raw ``bytes`` or a ``BytesIO``.
    :param tmp_path: pytest-provided temporary directory for the output file.
    """
    path = tmp_path / "test"

    writer = RawFileDataSaverBytes(path=path)
    writer.save_data(data)

    # Read the file back as raw bytes to compare against the input payload.
    written = path.read_bytes()

    # isinstance (rather than an exact type() comparison) is the idiomatic
    # type check and mirrors the branch taken inside save_data itself.
    expected = data.getvalue() if isinstance(data, io.BytesIO) else data
    assert written == expected