diff --git a/hamilton/io/default_data_loaders.py b/hamilton/io/default_data_loaders.py index e853b885d..870d83a62 100644 --- a/hamilton/io/default_data_loaders.py +++ b/hamilton/io/default_data_loaders.py @@ -4,7 +4,7 @@ import os import pathlib import pickle -from typing import Any, Collection, Dict, List, Tuple, Type, Union +from typing import Any, Collection, Dict, Tuple, Type, Union from hamilton.io.data_adapters import DataLoader, DataSaver from hamilton.io.utils import get_file_metadata @@ -16,7 +16,7 @@ class JSONDataLoader(DataLoader): @classmethod def applicable_types(cls) -> Collection[Type]: - return [dict, List[dict]] + return [dict, list] def load_data(self, type_: Type) -> Tuple[dict, Dict[str, Any]]: with open(self.path, "r") as f: @@ -33,7 +33,7 @@ class JSONDataSaver(DataSaver): @classmethod def applicable_types(cls) -> Collection[Type]: - return [dict, List[dict]] + return [dict, list] @classmethod def name(cls) -> str: diff --git a/tests/io/test_default_adapters.py b/tests/io/test_default_adapters.py index fa46052d7..0c7012ac0 100644 --- a/tests/io/test_default_adapters.py +++ b/tests/io/test_default_adapters.py @@ -1,7 +1,6 @@ import io import json import pathlib -from typing import List import pytest @@ -29,10 +28,12 @@ def test_raw_file_adapter(data, tmp_path: pathlib.Path) -> None: assert data_processed == data2 -@pytest.mark.parametrize("data", [{"key": "value"}, [{"key": "value1"}, {"key": "value2"}]]) +@pytest.mark.parametrize( + "data", [{"key": "value"}, [{"key": "value1"}, {"key": "value2"}], ["value1", "value2"], [0, 1]] +) def test_json_save_object_and_array(data, tmp_path: pathlib.Path): """Test that `from_.json` and `to.json` can handle JSON objects where - the top-level is an object `{ }` -> dict or an array `[ ]` -> list[dict] + the top-level is an object `{ }` -> dict or an array `[ ]` -> list """ data_path = tmp_path / "data.json" saver = JSONDataSaver(path=data_path) @@ -40,16 +41,18 @@ def test_json_save_object_and_array(data, tmp_path: pathlib.Path): metadata = saver.save_data(data) loaded_data = json.loads(data_path.read_text()) - assert JSONDataSaver.applicable_types() == [dict, List[dict]] + assert JSONDataSaver.applicable_types() == [dict, list] assert data_path.exists() assert metadata[FILE_METADATA]["path"] == str(data_path) assert data == loaded_data -@pytest.mark.parametrize("data", [{"key": "value"}, [{"key": "value1"}, {"key": "value2"}]]) +@pytest.mark.parametrize( + "data", [{"key": "value"}, [{"key": "value1"}, {"key": "value2"}], ["value1", "value2"], [0, 1]] +) def test_json_load_object_and_array(data, tmp_path: pathlib.Path): """Test that `from_.json` and `to.json` can handle JSON objects where - the top-level is an object `{ }` -> dict or an array `[ ]` -> list[dict] + the top-level is an object `{ }` -> dict or an array `[ ]` -> list """ data_path = tmp_path / "data.json" loader = JSONDataLoader(path=data_path) @@ -57,5 +60,5 @@ def test_json_load_object_and_array(data, tmp_path: pathlib.Path): json.dump(data, data_path.open("w")) loaded_data, metadata = loader.load_data(type(data)) - assert JSONDataLoader.applicable_types() == [dict, List[dict]] + assert JSONDataLoader.applicable_types() == [dict, list] assert data == loaded_data