Skip to content

Commit

Permalink
Adds bytesIO/bytes input for to.file
Browse files Browse the repository at this point in the history
This helps when we want to save binary data
  • Loading branch information
elijahbenizzy committed Dec 28, 2023
1 parent 99c3ee1 commit c44bad7
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 1 deletion.
29 changes: 28 additions & 1 deletion hamilton/io/default_data_loaders.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import dataclasses
import io
import json
import os
import pathlib
import pickle
from typing import Any, Collection, Dict, Tuple, Type
from typing import Any, Collection, Dict, Tuple, Type, Union

from hamilton.io.data_adapters import DataLoader, DataSaver
from hamilton.io.utils import get_file_metadata
Expand Down Expand Up @@ -80,6 +82,30 @@ def save_data(self, data: Any) -> Dict[str, Any]:
return get_file_metadata(self.path)


@dataclasses.dataclass
class RawFileDataSaverBytes(DataSaver):
path: Union[pathlib.Path, str]

@classmethod
def applicable_types(cls) -> Collection[Type]:
return [bytes, io.BytesIO]

@classmethod
def name(cls) -> str:
return "file"

def save_data(self, data: Union[bytes, io.BytesIO]) -> Dict[str, Any]:
if isinstance(data, io.BytesIO):
data_bytes = data.getvalue() # Extract bytes from BytesIO
else:
data_bytes = data

with open(self.path, "wb") as file:
file.write(data_bytes)

return get_file_metadata(str(self.path))


@dataclasses.dataclass
class PickleLoader(DataLoader):
path: str
Expand Down Expand Up @@ -172,6 +198,7 @@ def name(cls) -> str:
LiteralValueDataLoader,
RawFileDataLoader,
RawFileDataSaver,
RawFileDataSaverBytes,
PickleLoader,
PickleSaver,
EnvVarDataLoader,
Expand Down
26 changes: 26 additions & 0 deletions tests/io/test_default_adapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import io
import pathlib

import pytest

from hamilton.io.default_data_loaders import RawFileDataSaverBytes


@pytest.mark.parametrize(
"data",
[
b"test",
io.BytesIO(b"test"),
],
)
def test_raw_file_adapter(data, tmp_path: pathlib.Path) -> None:
path = tmp_path / "test"

writer = RawFileDataSaverBytes(path=path)
writer.save_data(data)

with open(path, "rb") as f:
data2 = f.read()

data_processed = data if type(data) is bytes else data.getvalue()
assert data_processed == data2

0 comments on commit c44bad7

Please sign in to comment.