From f4eafcaf36ddbb4fffe643553aff055f53c39a5d Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Sat, 1 Apr 2023 15:02:52 -0700 Subject: [PATCH] Adds implementation of data loaders See https://github.com/DAGWorks-Inc/hamilton/issues/46 for some initial context. Basic design is: 1. A decorator that loads data and injects into a function 2. A set of Data Loaders that can be constructed from kwargs (they're all dataclasses) 3. A LoadFromDecorator that takes in a variety of data loader classes that are possible to choose from 4. A registry that specifies a mapping of key -> data loader classes, allowing us to polymorphically select them More todos, up next 1. Add tests for the default loaders 2. Add safeguards to ensure any loader is a dataclass 3. Add documentation We also add a register_adapter function in the registry. Note that this just moves the adapters we currently have around, and does not add new ones. --- hamilton/function_modifiers/adapters.py | 258 +++++++++++++++ hamilton/io/__init__.py | 15 + hamilton/io/data_loaders.py | 92 ++++++ hamilton/io/default_data_loaders.py | 119 +++++++ hamilton/registry.py | 13 + tests/function_modifiers/test_adapters.py | 299 ++++++++++++++++++ tests/materialization/test_data_loaders.py | 26 ++ tests/resources/data/test_load_from_data.json | 25 ++ 8 files changed, 847 insertions(+) create mode 100644 hamilton/function_modifiers/adapters.py create mode 100644 hamilton/io/__init__.py create mode 100644 hamilton/io/data_loaders.py create mode 100644 hamilton/io/default_data_loaders.py create mode 100644 tests/function_modifiers/test_adapters.py create mode 100644 tests/materialization/test_data_loaders.py create mode 100644 tests/resources/data/test_load_from_data.json diff --git a/hamilton/function_modifiers/adapters.py b/hamilton/function_modifiers/adapters.py new file mode 100644 index 000000000..3c2a08233 --- /dev/null +++ b/hamilton/function_modifiers/adapters.py @@ -0,0 +1,258 @@ +import inspect +import typing +from 
typing import Any, Callable, Dict, List, Tuple, Type + +from hamilton import node +from hamilton.function_modifiers.base import InvalidDecoratorException, NodeCreator +from hamilton.function_modifiers.dependencies import ( + LiteralDependency, + ParametrizedDependency, + UpstreamDependency, +) +from hamilton.io.data_loaders import DataLoader +from hamilton.node import DependencyType +from hamilton.registry import ADAPTER_REGISTRY + + +class LoaderFactory: + """Factory for data loaders. This handles the fact that we pass in source(...) and value(...) + parameters to the data loaders.""" + + def __init__(self, loader_cls: Type[DataLoader], **kwargs: ParametrizedDependency): + """Initializes a loader factory. This takes in parameterized dependencies + and stores them for later resolution. + + Note that this is not strictly necessary -- we could easily put this in the decorator itself, but a separate factory keeps loader construction and validation in one place. + + :param loader_cls: Class of the loader to create. + :param kwargs: Keyword arguments to pass to the loader, as parameterized dependencies. + """ + self.loader_cls = loader_cls + self.kwargs = kwargs + self.validate() + + def validate(self): + """Validates that the loader class has the required arguments, and that + the arguments passed in are valid. + + :raises InvalidDecoratorException: If the arguments are invalid. 
+ """ + required_args = self.loader_cls.get_required_arguments() + optional_args = self.loader_cls.get_optional_arguments() + missing_params = set(required_args.keys()) - set(self.kwargs.keys()) + extra_params = ( + set(self.kwargs.keys()) - set(required_args.keys()) - set(optional_args.keys()) + ) + if len(missing_params) > 0: + raise InvalidDecoratorException( + f"Missing required parameters for loader : {self.loader_cls}: {missing_params}" + ) + if len(extra_params) > 0: + raise InvalidDecoratorException( + f"Extra parameters for loader: {self.loader_cls} {extra_params}" + ) + + def create_loader(self, **resolved_kwargs: Any) -> DataLoader: + return self.loader_cls(**resolved_kwargs) + + +class LoadFromDecorator(NodeCreator): + def __init__( + self, + loader_classes: typing.Sequence[Type[DataLoader]], + inject_=None, + **kwargs: ParametrizedDependency, + ): + """Instantiates a load_from decorator. This decorator will load from a data source, + and inject the loaded value into the decorated function. + + :param inject_: The name of the parameter to inject the data into. + :param loader_classes: The candidate data loader classes to choose from. + :param kwargs: The arguments to pass to the data loader. + """ + self.loader_classes = loader_classes + self.kwargs = kwargs + self.inject = inject_ + + def resolve_kwargs(self) -> Tuple[Dict[str, str], Dict[str, Any]]: + """Resolves kwargs to a list of dependencies, and a dictionary of name + to resolved literal values. + + :return: A tuple of the dependencies, and the resolved literal kwargs. + """ + dependencies = {} + resolved_kwargs = {} + for name, dependency in self.kwargs.items(): + if isinstance(dependency, UpstreamDependency): + dependencies[name] = dependency.source + elif isinstance(dependency, LiteralDependency): + resolved_kwargs[name] = dependency.value + return dependencies, resolved_kwargs + + def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node]: + """Generates two nodes: + 1. 
A node that loads the data from the data source, and returns that + metadata + 2. A node that takes the data from the data source, injects it into, and runs, the function. + + :param fn: The function to decorate. + :param config: The configuration to use. + :return: The resolved nodes + """ + loader_cls = self._resolve_loader_class(fn) + loader_factory = LoaderFactory(loader_cls, **self.kwargs) + # dependencies is a map from param name -> source name + # we use this to pass the right arguments to the loader. + dependencies, resolved_kwargs = self.resolve_kwargs() + # we need to invert the dependencies so that we can pass + # the right argument to the loader + dependencies_inverted = {v: k for k, v in dependencies.items()} + inject_parameter, load_type = self._get_inject_parameter(fn) + + def load_data( + __loader_factory: LoaderFactory = loader_factory, + __load_type: Type[Type] = load_type, + __resolved_kwargs=resolved_kwargs, + __dependencies=dependencies_inverted, + __optional_params=loader_cls.get_optional_arguments(), + **input_kwargs: Any, + ) -> Tuple[load_type, Dict[str, Any]]: + input_args_with_fixed_dependencies = { + __dependencies.get(key, key): value for key, value in input_kwargs.items() + } + kwargs = {**__resolved_kwargs, **input_args_with_fixed_dependencies} + data_loader = __loader_factory.create_loader(**kwargs) + return data_loader.load_data(load_type) + + def get_input_type_key(key: str) -> str: + return key if key not in dependencies else dependencies[key] + + input_types = { + get_input_type_key(key): (Any, DependencyType.REQUIRED) + for key in loader_cls.get_required_arguments() + } + input_types.update( + { + (get_input_type_key(key) if key not in dependencies else dependencies[key]): ( + Any, + DependencyType.OPTIONAL, + ) + for key in loader_cls.get_optional_arguments() + } + ) + # Take out all the resolved kwargs, as they are not dependencies, and will be filled out + # later + input_types = { + key: value for key, value in 
input_types.items() if key not in resolved_kwargs + } + + # the loader node is the node that loads the data from the data source. + loader_node = node.Node( + name=f"{inject_parameter}", + callabl=load_data, + typ=Tuple[Dict[str, Any], load_type], + input_types=input_types, + namespace=("load_data", fn.__name__), + ) + + # the inject node is the node that takes the data from the data source, and injects it into + # the function. + + def inject_function(**kwargs): + new_kwargs = kwargs.copy() + new_kwargs[inject_parameter] = kwargs[loader_node.name][0] + del new_kwargs[loader_node.name] + return fn(**new_kwargs) + + raw_node = node.Node.from_fn(fn) + new_input_types = { + (key if key != inject_parameter else loader_node.name): loader_node.type + for key, value in raw_node.input_types.items() + } + data_node = raw_node.copy_with( + input_types=new_input_types, + callabl=inject_function, + ) + return [loader_node, data_node] + + def _get_inject_parameter(self, fn: Callable) -> Tuple[str, Type[Type]]: + """Gets the name of the parameter to inject the data into. + + :param fn: The function to decorate. + :return: The name of the parameter to inject the data into, and its type hint. + """ + sig = inspect.signature(fn) + if self.inject is None: + if len(sig.parameters) != 1: + raise InvalidDecoratorException( + f"If you have multiple parameters in the signature, " + f"you must pass `inject_` to the load_from decorator for " + f"function: {fn.__qualname__}" + ) + inject = list(sig.parameters.keys())[0] + + else: + if self.inject not in sig.parameters: + raise InvalidDecoratorException( + f"Invalid inject parameter: {self.inject} for fn: {fn.__qualname__}" + ) + inject = self.inject + return inject, typing.get_type_hints(fn)[inject] + + def validate(self, fn: Callable): + """Validates the decorator. Currently this just calls the get_inject_parameter and + cascades the error which is all we know at validation time. 
+ + :param fn: The function being decorated. + :return: Nothing -- raises InvalidDecoratorException if validation fails. + """ + self._get_inject_parameter(fn) + cls = self._resolve_loader_class(fn) + loader_factory = LoaderFactory(cls, **self.kwargs) + loader_factory.validate() + + def _resolve_loader_class(self, fn: Callable) -> Type[DataLoader]: + """Resolves the loader class for a function. This will return the most recently + registered loader class that applies to the injection type, hence the reversed order. + + :param fn: Function to inject the loaded data into. + :return: The loader class to use. + """ + param, type_ = self._get_inject_parameter(fn) + for loader_cls in reversed(self.loader_classes): + if loader_cls.applies_to(type_): + return loader_cls + + raise InvalidDecoratorException( + f"No loader class found for type: {type_} specified by " + f"parameter: {param} in function: {fn.__qualname__}" + ) + + +class load_from__meta__(type): + def __getattr__(cls, item: str): + if item in ADAPTER_REGISTRY: + return load_from.decorator_factory(ADAPTER_REGISTRY[item]) + return super().__getattribute__(item) + + +class load_from(metaclass=load_from__meta__): + def __call__(self, *args, **kwargs): + return LoadFromDecorator(*args, **kwargs) + + @classmethod + def decorator_factory( + cls, loaders: typing.Sequence[Type[DataLoader]] + ) -> Callable[..., LoadFromDecorator]: + """Effectively a partial function for the load_from decorator. Broken into its own function (rather than + using functools.partial) as it is a little clearer to parse. + + :param loaders: Options of data loader classes to use + :return: The data loader decorator. 
+ """ + + def create_decorator( + __loaders=tuple(loaders), inject_=None, **kwargs: ParametrizedDependency + ): + return LoadFromDecorator(__loaders, inject_=inject_, **kwargs) + + return create_decorator diff --git a/hamilton/io/__init__.py b/hamilton/io/__init__.py new file mode 100644 index 000000000..85daf7ccd --- /dev/null +++ b/hamilton/io/__init__.py @@ -0,0 +1,15 @@ +import logging + +from hamilton.io.default_data_loaders import DATA_LOADERS +from hamilton.registry import register_adapter + +logger = logging.getLogger(__name__) + +registered = False +# Register all the default ones +if not registered: + logger.debug(f"Registering default data loaders: {DATA_LOADERS}") + for data_loader in DATA_LOADERS: + register_adapter(data_loader) + +registered = True diff --git a/hamilton/io/data_loaders.py b/hamilton/io/data_loaders.py new file mode 100644 index 000000000..ce01c7d50 --- /dev/null +++ b/hamilton/io/data_loaders.py @@ -0,0 +1,92 @@ +import abc +import dataclasses +import typing +from typing import Any, Dict, Tuple, Type, TypeVar + +LoadType = TypeVar("LoadType") + + +class DataLoader(abc.ABC): + """Base class for data loaders. Data loaders are used to load data from a data source. + Note that they are inherently polymorphic -- they declare what type(s) they can load to, + and may choose to load differently depending on the type they are loading to. + + Note that this is not yet a public-facing API -- the current set of data loaders will + be managed by the library, and the user will not be able to create their own. + + We intend to change this and provide an extensible user-facing API, + but if you subclass this, beware! It might change. + """ + + @classmethod + @abc.abstractmethod + def applies_to(cls, type_: Type[Type]) -> bool: + """Tells whether or not this data loader can load to a specific type. + For instance, a CSV data loader might be able to load to a dataframe, + a json, but not an integer. 
+ + This is a classmethod as it will be easier to validate, and we have to + construct this, delayed, with a factory. + + :param type_: Candidate type + :return: True if this data loader can load to the type, False otherwise. + """ + pass + + @abc.abstractmethod + def load_data(self, type_: Type[LoadType]) -> Tuple[LoadType, Dict[str, Any]]: + """Loads the data from the data source. + Note this uses the constructor parameters to determine + how to load the data. + + :return: The type specified + """ + pass + + @classmethod + @abc.abstractmethod + def name(cls) -> str: + """Returns the name of the data loader. This is used to register the data loader + with the load_from decorator. + + :return: The name of the data loader. + """ + pass + + @classmethod + def _ensure_dataclass(cls): + if not dataclasses.is_dataclass(cls): + raise TypeError( + f"DataLoader subclasses must be dataclasses. {cls.__qualname__} is not." + f" Did you forget to add @dataclass?" + ) + + @classmethod + def get_required_arguments(cls) -> Dict[str, Type[Type]]: + """Gives the required arguments for the class. + Note that this just uses the type hints from the dataclass. + + :return: The required arguments for the class. + """ + cls._ensure_dataclass() + type_hints = typing.get_type_hints(cls) + return { + field.name: type_hints.get(field.name) + for field in dataclasses.fields(cls) + if field.default == dataclasses.MISSING + } + + @classmethod + def get_optional_arguments(cls) -> Dict[str, Tuple[Type[Type], Any]]: + """Gives the optional arguments for the class. + Note that this just uses the type hints from the dataclass. + + :return: The optional arguments for the class. 
+ """ + cls._ensure_dataclass() + type_hints = typing.get_type_hints(cls) + return { + field.name: type_hints.get(field.name) + for field in dataclasses.fields(cls) + if field.default != dataclasses.MISSING + } diff --git a/hamilton/io/default_data_loaders.py b/hamilton/io/default_data_loaders.py new file mode 100644 index 000000000..5e001aca4 --- /dev/null +++ b/hamilton/io/default_data_loaders.py @@ -0,0 +1,119 @@ +import dataclasses +import json +import os +import pickle +from datetime import datetime +from typing import Any, Dict, Tuple, Type + +from hamilton.htypes import custom_subclass_check +from hamilton.io.data_loaders import DataLoader, LoadType + + +def get_file_loading_metadata(path: str) -> Dict[str, Any]: + """Gives metadata from loading a file. + This includes: + - the file size + - the file path + - the last modified time + - the current time + """ + return { + "file_size": os.path.getsize(path), + "file_path": path, + "file_last_modified": os.path.getmtime(path), + "file_loaded_at": datetime.now().utcnow().timestamp(), + } + + +@dataclasses.dataclass +class JSONDataLoader(DataLoader): + path: str + + @classmethod + def applies_to(cls, type_: Type[Type]) -> bool: + return custom_subclass_check(type_, dict) + + def load_data(self, type_: Type[LoadType]) -> Tuple[LoadType, Dict[str, Any]]: + with open(self.path, "r") as f: + return json.load(f), get_file_loading_metadata(self.path) + + @classmethod + def name(cls) -> str: + return "json" + + +@dataclasses.dataclass +class LiteralValueDataLoader(DataLoader): + value: Any + + @classmethod + def applies_to(cls, type_: Type[Type]) -> bool: + return True + + def load_data(self, type_: Type[LoadType]) -> Tuple[LoadType, Dict[str, Any]]: + return self.value, {} + + @classmethod + def name(cls) -> str: + return "literal" + + +@dataclasses.dataclass +class RawFileDataLoader(DataLoader): + path: str + encoding: str = "utf-8" + + @classmethod + def applies_to(cls, type_: Type[Type]) -> bool: + return 
custom_subclass_check(type_, str) + + def load_data(self, type_: Type[LoadType]) -> Tuple[LoadType, Dict[str, Any]]: + with open(self.path, "r", encoding=self.encoding) as f: + return f.read(), get_file_loading_metadata(self.path) + + @classmethod + def name(cls) -> str: + return "file" + + +@dataclasses.dataclass +class PickleLoader(DataLoader): + path: str + + @classmethod + def applies_to(cls, type_: Type[Type]) -> bool: + return True # no way to know beforehand + + def load_data(self, type_: Type[LoadType]) -> Tuple[LoadType, Dict[str, Any]]: + with open(self.path, "rb") as f: + return pickle.load(f), get_file_loading_metadata(self.path) + + @classmethod + def name(cls) -> str: + return "pickle" + + +@dataclasses.dataclass +class EnvVarDataLoader(DataLoader): + + names: Tuple[str, ...] + + @classmethod + def applies_to(cls, type_: Type[Type]) -> bool: + return custom_subclass_check(type_, dict) + + def load_data(self, type_: Type[LoadType]) -> Tuple[LoadType, Dict[str, Any]]: + return {name: os.environ[name] for name in self.names}, {} + + @classmethod + def name(cls) -> str: + return "environment" + + +DATA_LOADERS = [ + JSONDataLoader, + LiteralValueDataLoader, + RawFileDataLoader, + PickleLoader, + EnvVarDataLoader, +] diff --git a/hamilton/registry.py b/hamilton/registry.py index 17ef1d004..438b87e22 100644 --- a/hamilton/registry.py +++ b/hamilton/registry.py @@ -1,3 +1,4 @@ +import collections import functools import importlib import logging @@ -84,3 +85,15 @@ def load_extension(plugin_module: str): mod, f"fill_with_scalar_{plugin_module}" ), f"Error extension missing fill_with_scalar_{plugin_module}" logger.info(f"Detected {plugin_module} and successfully loaded Hamilton extensions.") + + +ADAPTER_REGISTRY = collections.defaultdict(list) + + +def register_adapter(adapter: Any): + """Registers an adapter. Note that the type is any, + because we can't import it here due to circular imports. + + :param adapter: the adapter to register. 
+ """ + ADAPTER_REGISTRY[adapter.name()].append(adapter) diff --git a/tests/function_modifiers/test_adapters.py b/tests/function_modifiers/test_adapters.py new file mode 100644 index 000000000..b1be944b7 --- /dev/null +++ b/tests/function_modifiers/test_adapters.py @@ -0,0 +1,299 @@ +import dataclasses +from collections import Counter +from typing import Any, Dict, Tuple, Type + +import pytest + +from hamilton import ad_hoc_utils, base, driver, graph +from hamilton.function_modifiers import base as fm_base +from hamilton.function_modifiers import source, value +from hamilton.function_modifiers.adapters import LoadFromDecorator, load_from +from hamilton.io.data_loaders import DataLoader, LoadType +from hamilton.registry import ADAPTER_REGISTRY + + +def test_default_adapters_are_available(): + assert len(ADAPTER_REGISTRY) > 0 + + +def test_default_adapters_are_registered_once(): + assert "json" in ADAPTER_REGISTRY + count_unique = { + key: Counter([value.__class__.__qualname__ for value in values]) + for key, values in ADAPTER_REGISTRY.items() + } + for key, value_ in count_unique.items(): + for impl, count in value_.items(): + assert count == 1, ( + f"Adapter registered multiple times for {impl}. This should not" + f" happen, as items should just be registered once." 
+ ) + + +@dataclasses.dataclass +class MockDataLoader(DataLoader): + required_param: int + required_param_2: int + required_param_3: str + default_param: int = 4 + default_param_2: int = 5 + default_param_3: str = "6" + + @classmethod + def applies_to(cls, type_: Type[Type]) -> bool: + return issubclass(type_, int) + + def load_data(self, type_: Type[int]) -> Tuple[int, Dict[str, Any]]: + return ..., {"required_param": self.required_param, "default_param": self.default_param} + + @classmethod + def name(cls) -> str: + return "mock" + + +def test_load_from_decorator_resolve_kwargs(): + decorator = LoadFromDecorator( + [MockDataLoader], + required_param=source("1"), + required_param_2=value(2), + required_param_3=value("3"), + default_param=source("4"), + default_param_2=value(5), + ) + + dependency_kwargs, literal_kwargs = decorator.resolve_kwargs() + assert dependency_kwargs == {"required_param": "1", "default_param": "4"} + assert literal_kwargs == {"required_param_2": 2, "required_param_3": "3", "default_param_2": 5} + + +def test_load_from_decorator_validate_succeeds(): + decorator = LoadFromDecorator( + [MockDataLoader], + required_param=source("1"), + required_param_2=value(2), + required_param_3=value("3"), + ) + + def fn(injected_data: int) -> int: + return injected_data + + decorator.validate(fn) + + +def test_load_from_decorator_validate_succeeds_with_inject(): + decorator = LoadFromDecorator( + [MockDataLoader], + inject_="injected_data", + required_param=source("1"), + required_param_2=value(2), + required_param_3=value("3"), + ) + + def fn(injected_data: int, dependent_data: int) -> int: + return injected_data + dependent_data + + decorator.validate(fn) + + +def test_load_from_decorator_validate_fails_dont_know_which_param_to_inject(): + decorator = LoadFromDecorator( + [MockDataLoader], + required_param=source("1"), + required_param_2=value(2), + required_param_3=value("3"), + ) + + def fn(injected_data: int, other_possible_injected_data: int) -> int: 
+ return injected_data + other_possible_injected_data + + with pytest.raises(fm_base.InvalidDecoratorException): + decorator.validate(fn) + + +def test_load_from_decorator_validate_fails_inject_not_in_fn(): + decorator = LoadFromDecorator( + [MockDataLoader], + inject_="injected_data", + required_param=source("1"), + required_param_2=value(2), + required_param_3=value("3"), + ) + + def fn(dependent_data: int) -> int: + return dependent_data + + with pytest.raises(fm_base.InvalidDecoratorException): + decorator.validate(fn) + + +@dataclasses.dataclass +class StringDataLoader(DataLoader): + @classmethod + def applies_to(cls, type_: Type[Type]) -> bool: + return issubclass(str, type_) + + def load_data(self, type_: Type[LoadType]) -> Tuple[LoadType, Dict[str, Any]]: + return "foo", {"loader": "string_data_loader"} + + @classmethod + def name(cls) -> str: + return "dummy" + + +@dataclasses.dataclass +class IntDataLoader(DataLoader): + @classmethod + def applies_to(cls, type_: Type[Type]) -> bool: + return issubclass(int, type_) + + def load_data(self, type_: Type[LoadType]) -> Tuple[LoadType, Dict[str, Any]]: + return 1, {"loader": "int_data_loader"} + + @classmethod + def name(cls) -> str: + return "dummy" + + +@dataclasses.dataclass +class IntDataLoaderClass2(DataLoader): + @classmethod + def applies_to(cls, type_: Type[Type]) -> bool: + return issubclass(int, type_) + + def load_data(self, type_: Type[LoadType]) -> Tuple[LoadType, Dict[str, Any]]: + return 2, {"loader": "int_data_loader_class_2"} + + @classmethod + def name(cls) -> str: + return "dummy" + + +def test_validate_fails_incorrect_type(): + decorator = LoadFromDecorator( + [StringDataLoader, IntDataLoader], + ) + + def fn_str_inject(injected_data: str) -> str: + return injected_data + + def fn_int_inject(injected_data: int) -> int: + return injected_data + + def fn_bool_inject(injected_data: bool) -> bool: + return injected_data + + # This is valid as there is one parameter and its a type that the 
decorator supports + decorator.validate(fn_str_inject) + + # This is valid as there is one parameter and its a type that the decorator supports + decorator.validate(fn_int_inject) + + # This is invalid as there is one parameter and it is not a type that the decorator supports + with pytest.raises(fm_base.InvalidDecoratorException): + decorator.validate(fn_bool_inject) + + +def test_validate_selects_correct_type(): + decorator = LoadFromDecorator( + [StringDataLoader, IntDataLoader], + ) + + def fn_str_inject(injected_data: str) -> str: + return injected_data + + def fn_int_inject(injected_data: int) -> int: + return injected_data + + def fn_bool_inject(injected_data: bool) -> bool: + return injected_data + + # This is valid as there is one parameter and its a type that the decorator supports + decorator.validate(fn_str_inject) + + # This is valid as there is one parameter and its a type that the decorator supports + decorator.validate(fn_int_inject) + + # This is invalid as there is one parameter and it is not a type that the decorator supports + with pytest.raises(fm_base.InvalidDecoratorException): + decorator.validate(fn_bool_inject) + + +# Note that this tests an internal API, but we would like to test this to ensure +# class selection is correct +def test_load_from_resolves_correct_class(): + decorator = LoadFromDecorator( + [StringDataLoader, IntDataLoader, IntDataLoaderClass2], + ) + + def fn_str_inject(injected_data: str) -> str: + return injected_data + + def fn_int_inject(injected_data: int) -> int: + return injected_data + + def fn_bool_inject(injected_data: bool) -> bool: + return injected_data + + # This is valid as there is one parameter and its a type that the decorator supports + assert decorator._resolve_loader_class(fn_str_inject) == StringDataLoader + assert decorator._resolve_loader_class(fn_int_inject) == IntDataLoaderClass2 + # This is invalid as there is one parameter and it is not a type that the decorator supports + with 
pytest.raises(fm_base.InvalidDecoratorException): + decorator._resolve_loader_class(fn_bool_inject) + + +# End-to-end tests are probably cleanest +# We've done a bunch of tests of internal structures for other decorators, +# but that leaves the testing brittle +# We don't test the driver, we just use the function_graph to tests the nodes +def test_load_from_decorator_end_to_end(): + @LoadFromDecorator( + [StringDataLoader, IntDataLoader, IntDataLoaderClass2], + ) + def fn_str_inject(injected_data: str) -> str: + return injected_data + + config = {} + adapter = base.SimplePythonGraphAdapter(base.DictResult()) + fg = graph.FunctionGraph( + ad_hoc_utils.create_temporary_module(fn_str_inject), config=config, adapter=adapter + ) + result = fg.execute(inputs={}, nodes=fg.nodes.values()) + assert result["fn_str_inject"] == "foo" + assert result["load_data.fn_str_inject.injected_data"] == ( + "foo", + {"loader": "string_data_loader"}, + ) + + +@pytest.mark.parametrize( + "source_", + [ + value("tests/resources/data/test_load_from_data.json"), + source("test_data"), + ], +) +def test_load_from_decorator_json_file(source_): + @load_from.json(path=source_) + def raw_json_data(data: Dict[str, Any]) -> Dict[str, Any]: + return data + + def number_employees(raw_json_data: Dict[str, Any]) -> int: + return len(raw_json_data["employees"]) + + def sum_age(raw_json_data: Dict[str, Any]) -> float: + return sum([employee["age"] for employee in raw_json_data["employees"]]) + + def mean_age(sum_age: float, number_employees: int) -> float: + return sum_age / number_employees + + config = {} + dr = driver.Driver( + config, + ad_hoc_utils.create_temporary_module(raw_json_data, number_employees, sum_age, mean_age), + adapter=base.SimplePythonGraphAdapter(base.DictResult()), + ) + result = dr.execute( + ["mean_age"], inputs={"test_data": "tests/resources/data/test_load_from_data.json"} + ) + assert result["mean_age"] - 32.33333 < 0.0001 diff --git a/tests/materialization/test_data_loaders.py 
b/tests/materialization/test_data_loaders.py new file mode 100644 index 000000000..71fa138e1 --- /dev/null +++ b/tests/materialization/test_data_loaders.py @@ -0,0 +1,26 @@ +import dataclasses +from typing import Any, Dict, Tuple, Type + +from hamilton.io.data_loaders import DataLoader, LoadType + + +@dataclasses.dataclass +class MockDataLoader(DataLoader): + required_param: int + default_param: int = 1 + + @classmethod + def applies_to(cls, type_: Type[Type]) -> bool: + return True + + def load_data(self, type_: Type[LoadType]) -> Tuple[LoadType, Dict[str, Any]]: + pass + + @classmethod + def name(cls) -> str: + pass + + +def test_data_loader_get_required_params(): + assert MockDataLoader.get_required_arguments() == {"required_param": int} + assert MockDataLoader.get_optional_arguments() == {"default_param": int} diff --git a/tests/resources/data/test_load_from_data.json b/tests/resources/data/test_load_from_data.json new file mode 100644 index 000000000..8f15f58e9 --- /dev/null +++ b/tests/resources/data/test_load_from_data.json @@ -0,0 +1,25 @@ +{ + "employees": [ + { + "firstName": "John", + "lastName": "Doe", + "age": 25, + "department": "IT", + "email": "john.doe@example.com" + }, + { + "firstName": "Jane", + "lastName": "Smith", + "age": 32, + "department": "HR", + "email": "jane.smith@example.com" + }, + { + "firstName": "Bob", + "lastName": "Johnson", + "age": 40, + "department": "Marketing", + "email": "bob.johnson@example.com" + } + ] +}