diff --git a/hamilton/function_modifiers/adapters.py b/hamilton/function_modifiers/adapters.py index fa476be89..6c2a793c6 100644 --- a/hamilton/function_modifiers/adapters.py +++ b/hamilton/function_modifiers/adapters.py @@ -1,33 +1,39 @@ import inspect import typing -from typing import Any, Callable, Dict, List, Tuple, Type +from typing import Any, Callable, Collection, Dict, List, Tuple, Type from hamilton import node -from hamilton.function_modifiers.base import InvalidDecoratorException, NodeCreator +from hamilton.function_modifiers.base import ( + InvalidDecoratorException, + NodeCreator, + SingleNodeNodeTransformer, +) from hamilton.function_modifiers.dependencies import ( LiteralDependency, ParametrizedDependency, UpstreamDependency, ) -from hamilton.io.data_loaders import DataLoader +from hamilton.io.data_adapters import AdapterCommon, DataLoader, DataSaver from hamilton.node import DependencyType -from hamilton.registry import ADAPTER_REGISTRY +from hamilton.registry import LOADER_REGISTRY, SAVER_REGISTRY -class LoaderFactory: +class AdapterFactory: """Factory for data loaders. This handles the fact that we pass in source(...) and value(...) parameters to the data loaders.""" - def __init__(self, loader_cls: Type[DataLoader], **kwargs: ParametrizedDependency): - """Initializes a loader factory. This takes in parameterized dependencies + def __init__(self, adapter_cls: Type[AdapterCommon], **kwargs: ParametrizedDependency): + """Initializes an adapter factory. This takes in parameterized dependencies and stores them for later resolution. Note that this is not strictly necessary -- we could easily put this in the + decorator, but I wanted to separate out/consolidate the logic between data savers and data + loaders. - :param loader_cls: Class of the loader to create. + :param adapter_cls: Class of the loader to create. :param kwargs: Keyword arguments to pass to the loader, as parameterized dependencies. """ - self.loader_cls = loader_cls + self.adapter_cls = adapter_cls self.kwargs = kwargs self.validate() @@ -37,25 +43,62 @@ def validate(self): :raises InvalidDecoratorException: If the arguments are invalid. """ - required_args = self.loader_cls.get_required_arguments() - optional_args = self.loader_cls.get_optional_arguments() + required_args = self.adapter_cls.get_required_arguments() + optional_args = self.adapter_cls.get_optional_arguments() missing_params = set(required_args.keys()) - set(self.kwargs.keys()) extra_params = ( set(self.kwargs.keys()) - set(required_args.keys()) - set(optional_args.keys()) ) if len(missing_params) > 0: raise InvalidDecoratorException( - f"Missing required parameters for loader : {self.loader_cls}: {missing_params}. " + f"Missing required parameters for adapter : {self.adapter_cls}: {missing_params}. " f"Required parameters/types are: {required_args}. Optional parameters/types are: " f"{optional_args}. " ) if len(extra_params) > 0: raise InvalidDecoratorException( - f"Extra parameters for loader: {self.loader_cls} {extra_params}" + f"Extra parameters for loader: {self.adapter_cls} {extra_params}" ) def create_loader(self, **resolved_kwargs: Any) -> DataLoader: - return self.loader_cls(**resolved_kwargs) + if not self.adapter_cls.can_load(): + raise InvalidDecoratorException(f"Adapter {self.adapter_cls} cannot load data.") + return self.adapter_cls(**resolved_kwargs) + + def create_saver(self, **resolved_kwargs: Any) -> DataSaver: + if not self.adapter_cls.can_save(): + raise InvalidDecoratorException(f"Adapter {self.adapter_cls} cannot save data.") + return self.adapter_cls(**resolved_kwargs) + + +def resolve_kwargs(kwargs: Dict[str, Any]) -> Tuple[Dict[str, str], Dict[str, Any]]: + """Resolves kwargs to a list of dependencies, and a dictionary of name + to resolved literal values. + + :return: A tuple of the dependencies, and the resolved literal kwargs. + """ + dependencies = {} + resolved_kwargs = {} + for name, dependency in kwargs.items(): + if isinstance(dependency, UpstreamDependency): + dependencies[name] = dependency.source + elif isinstance(dependency, LiteralDependency): + resolved_kwargs[name] = dependency.value + return dependencies, resolved_kwargs + + +def resolve_adapter_class( + type_: Type[Type], loader_classes: List[Type[AdapterCommon]] +) -> Type[AdapterCommon]: + """Resolves the loader class for a function. This will return the most recently + registered loader class that applies to the injection type, hence the reversed order. + + :param fn: Function to inject the loaded data into. + :return: The loader class to use. + """ + for loader_cls in reversed(loader_classes): + if loader_cls.applies_to(type_): + return loader_cls class LoadFromDecorator(NodeCreator): @@ -76,21 +119,6 @@ def __init__( self.kwargs = kwargs self.inject = inject_ - def resolve_kwargs(self) -> Tuple[Dict[str, str], Dict[str, Any]]: - """Resolves kwargs to a list of dependencies, and a dictionary of name - to resolved literal values. - - :return: A tuple of the dependencies, and the resolved literal kwargs. - """ - dependencies = {} - resolved_kwargs = {} - for name, dependency in self.kwargs.items(): - if isinstance(dependency, UpstreamDependency): - dependencies[name] = dependency.source - elif isinstance(dependency, LiteralDependency): - resolved_kwargs[name] = dependency.value - return dependencies, resolved_kwargs - def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node]: """Generates two nodes: 1. A node that loads the data from the data source, and returns that + metadata @@ -100,18 +128,27 @@ def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node :param config: The configuration to use. :return: The resolved nodes """ - loader_cls = self._resolve_loader_class(fn) - loader_factory = LoaderFactory(loader_cls, **self.kwargs) + inject_parameter, type_ = self._get_inject_parameter(fn) + loader_cls = resolve_adapter_class( + type_, + self.loader_classes, + ) + if loader_cls is None: + raise InvalidDecoratorException( + f"No loader class found for type: {type_} specified by " + f"parameter: {inject_parameter} in function: {fn.__qualname__}" + ) + loader_factory = AdapterFactory(loader_cls, **self.kwargs) # dependencies is a map from param name -> source name # we use this to pass the right arguments to the loader. - dependencies, resolved_kwargs = self.resolve_kwargs() + dependencies, resolved_kwargs = resolve_kwargs(self.kwargs) # we need to invert the dependencies so that we can pass # the right argument to the loader dependencies_inverted = {v: k for k, v in dependencies.items()} inject_parameter, load_type = self._get_inject_parameter(fn) def load_data( - __loader_factory: LoaderFactory = loader_factory, + __loader_factory: AdapterFactory = loader_factory, __load_type: Type[Type] = load_type, __resolved_kwargs=resolved_kwargs, __dependencies=dependencies_inverted, @@ -207,42 +244,30 @@ def validate(self, fn: Callable): :param fn: :return: """ - self._get_inject_parameter(fn) - cls = self._resolve_loader_class(fn) - loader_factory = LoaderFactory(cls, **self.kwargs) + inject_parameter, type_ = self._get_inject_parameter(fn) + cls = resolve_adapter_class(type_, self.loader_classes) + if cls is None: + raise InvalidDecoratorException( + f"No loader class found for type: {type_} specified by " + f"parameter: {inject_parameter} in function: {fn.__qualname__}" + ) + loader_factory = AdapterFactory(cls, **self.kwargs) loader_factory.validate() - def _resolve_loader_class(self, fn: Callable) -> Type[DataLoader]: - """Resolves the loader class for a function. This will return the most recently - registered loader class that applies to the injection type, hence the reversed order. - - :param fn: Function to inject the loaded data into. - :return: The loader class to use. - """ - param, type_ = self._get_inject_parameter(fn) - for loader_cls in reversed(self.loader_classes): - if loader_cls.applies_to(type_): - return loader_cls - - raise InvalidDecoratorException( - f"No loader class found for type: {type_} specified by " - f"parameter: {param} in function: {fn.__qualname__}" - ) - class load_from__meta__(type): def __getattr__(cls, item: str): - if item in ADAPTER_REGISTRY: - return load_from.decorator_factory(ADAPTER_REGISTRY[item]) + if item in LOADER_REGISTRY: + return load_from.decorator_factory(LOADER_REGISTRY[item]) try: return super().__getattribute__(item) except AttributeError as e: raise AttributeError( f"No loader named: {item} available for {cls.__name__}. " - f"Available loaders are: {ADAPTER_REGISTRY.keys()}. " + f"Available loaders are: {LOADER_REGISTRY.keys()}. " f"If you've gotten to this point, you either (1) spelled the " f"loader name wrong, (2) are trying to use a loader that does" - f"not exist (yet), or (3) have implemented it and " + f"not exist (yet)" ) from e @@ -329,3 +354,130 @@ def create_decorator( return LoadFromDecorator(__loaders, inject_=inject_, **kwargs) return create_decorator + + +class save_to__meta__(type): + def __getattr__(cls, item: str): + if item in SAVER_REGISTRY: + return save_to.decorator_factory(SAVER_REGISTRY[item]) + try: + return super().__getattribute__(item) + except AttributeError as e: + raise AttributeError( + f"No saver named: {item} available for {cls.__name__}. " + f"Available data savers are: {SAVER_REGISTRY.keys()}. " + f"If you've gotten to this point, you either (1) spelled the " + f"loader name wrong, (2) are trying to use a saver that does" + f"not exist (yet)" + ) from e + + +class SaveToDecorator(SingleNodeNodeTransformer): + def __init__( + self, + saver_classes_: typing.Sequence[Type[DataSaver]], + artifact_name_: str = None, + **kwargs: ParametrizedDependency, + ): + self.artifact_name = artifact_name_ + self.saver_classes = saver_classes_ + self.kwargs = kwargs + + def transform_node( + self, node_: node.Node, config: Dict[str, Any], fn: Callable + ) -> Collection[node.Node]: + artifact_name = self.artifact_name + artifact_namespace = () + + if artifact_name is None: + artifact_name = node_.name + artifact_namespace = ("save",) + + type_ = node_.type + saver_cls = resolve_adapter_class( + type_, + self.saver_classes, + ) + if saver_cls is None: + raise InvalidDecoratorException( + f"No saver class found for type: {type_} specified by " + f"output type: {type_} in node: {node_.name} generated by " + f"function: {fn.__qualname__}." + ) + + adapter_factory = AdapterFactory(saver_cls, **self.kwargs) + dependencies, resolved_kwargs = resolve_kwargs(self.kwargs) + dependencies_inverted = {v: k for k, v in dependencies.items()} + + def save_data( + __adapter_factory=adapter_factory, + __dependencies=dependencies_inverted, + __resolved_kwargs=resolved_kwargs, + __data_node_name=node_.name, + **input_kwargs, + ) -> Dict[str, Any]: + input_args_with_fixed_dependencies = { + __dependencies.get(key, key): value for key, value in input_kwargs.items() + } + kwargs = {**__resolved_kwargs, **input_args_with_fixed_dependencies} + data_to_save = kwargs[__data_node_name] + kwargs = {k: v for k, v in kwargs.items() if k != __data_node_name} + data_saver = __adapter_factory.create_saver(**kwargs) + return data_saver.save_data(data_to_save) + + def get_input_type_key(key: str) -> str: + return key if key not in dependencies else dependencies[key] + + input_types = { + get_input_type_key(key): (Any, DependencyType.REQUIRED) + for key in saver_cls.get_required_arguments() + } + input_types.update( + { + (get_input_type_key(key) if key not in dependencies else dependencies[key]): ( + Any, + DependencyType.OPTIONAL, + ) + for key in saver_cls.get_optional_arguments() + } + ) + # Take out all the resolved kwargs, as they are not dependencies, and will be filled out + # later + input_types = { + key: value for key, value in input_types.items() if key not in resolved_kwargs + } + input_types[node_.name] = (node_.type, DependencyType.REQUIRED) + + save_node = node.Node( + name=artifact_name, + callabl=save_data, + typ=Dict[str, Any], + input_types=input_types, + namespace=artifact_namespace, + ) + return [save_node, node_] + + def validate(self, fn: Callable): + pass + + +class save_to(metaclass=load_from__meta__): + def __call__(self, *args, **kwargs): + return LoadFromDecorator(*args, **kwargs) + + @classmethod + def decorator_factory( + cls, savers: typing.Sequence[Type[DataSaver]] + ) -> Callable[..., SaveToDecorator]: + """Effectively a partial function for the load_from decorator. Broken into its own ( + rather than using functools.partial) as it is a little clearer to parse. + + :param savers: Candidate data savers + :param loaders: Options of data loader classes to use + :return: The data loader decorator. + """ + + def create_decorator(__savers=tuple(savers), **kwargs: ParametrizedDependency): + return SaveToDecorator(__savers, **kwargs) + + return create_decorator diff --git a/hamilton/io/__init__.py b/hamilton/io/__init__.py index 85daf7ccd..7f2664783 100644 --- a/hamilton/io/__init__.py +++ b/hamilton/io/__init__.py @@ -1,6 +1,6 @@ import logging -from hamilton.io.default_data_loaders import DATA_LOADERS +from hamilton.io.default_data_loaders import DATA_ADAPTERS from hamilton.registry import register_adapter logger = logging.getLogger(__name__) @@ -8,8 +8,8 @@ registered = False # Register all the default ones if not registered: - logger.debug(f"Registering default data loaders: {DATA_LOADERS}") - for data_loader in DATA_LOADERS: + logger.debug(f"Registering default data loaders: {DATA_ADAPTERS}") + for data_loader in DATA_ADAPTERS: register_adapter(data_loader) registered = True diff --git a/hamilton/io/data_loaders.py b/hamilton/io/data_adapters.py similarity index 68% rename from hamilton/io/data_loaders.py rename to hamilton/io/data_adapters.py index 152c6026f..991266723 100644 --- a/hamilton/io/data_loaders.py +++ b/hamilton/io/data_adapters.py @@ -6,18 +6,7 @@ from hamilton.htypes import custom_subclass_check -class DataLoader(abc.ABC): - """Base class for data loaders. Data loaders are used to load data from a data source. - Note that they are inherently polymorphic -- they declare what type(s) they can load to, - and may choose to load differently depending on the type they are loading to. - - Note that this is not yet a public-facing API -- the current set of data loaders will - be managed by the library, and the user will not be able to create their own. - - We intend to change this and provide an extensible user-facing API, - but if you subclass this, beware! It might change. - """ - +class AdapterCommon(abc.ABC): @classmethod @abc.abstractmethod def load_targets(cls) -> Collection[Type]: @@ -53,16 +42,6 @@ def applies_to(cls, type_: Type[Type]) -> bool: return True return False - @abc.abstractmethod - def load_data(self, type_: Type[Type]) -> Tuple[Type, Dict[str, Any]]: - """Loads the data from the data source. - Note this uses the constructor parameters to determine - how to load the data. - - :return: The type specified - """ - pass - @classmethod @abc.abstractmethod def name(cls) -> str: @@ -110,3 +89,79 @@ def get_optional_arguments(cls) -> Dict[str, Tuple[Type[Type], Any]]: for field in dataclasses.fields(cls) if field.default != dataclasses.MISSING } + + @classmethod + def can_load(cls) -> bool: + """Returns whether this adapter can "load" data. + Subclasses are meant to implement this function to + tell the framework what to do with them. + + :return: + """ + return False + + @classmethod + def can_save(cls) -> bool: + """Returns whether this adapter can "save" data. + Subclasses are meant to implement this function to + tell the framework what to do with them. + + :return: + """ + return False + + +class DataLoader(AdapterCommon, abc.ABC): + """Base class for data loaders. Data loaders are used to load data from a data source. + Note that they are inherently polymorphic -- they declare what type(s) they can load to, + and may choose to load differently depending on the type they are loading to. + + Note that this is not yet a public-facing API -- the current set of data loaders will + be managed by the library, and the user will not be able to create their own. + + We intend to change this and provide an extensible user-facing API, + but if you subclass this, beware! It might change. + """ + + @abc.abstractmethod + def load_data(self, type_: Type[Type]) -> Tuple[Type, Dict[str, Any]]: + """Loads the data from the data source. + Note this uses the constructor parameters to determine + how to load the data. + + :return: The type specified + """ + pass + + @classmethod + def can_load(cls) -> bool: + return True + + +class DataSaver(AdapterCommon, abc.ABC): + """Base class for data savers. Data savers are used to save data to a data source. + Note that they are inherently polymorphic -- they declare what type(s) they can save from, + and may choose to save differently depending on the type they are saving from. + + Note that this is not yet a public-facing API -- the current set of data savers will + be managed by the library, and the user will not be able to create their own. + + We intend to change this and provide an extensible user-facing API, + but if you subclass this, beware! It might change. + """ + + @abc.abstractmethod + def save_data(self, data: Any) -> Dict[str, Any]: + """Saves the data to the data source. + Note this uses the constructor parameters to determine + how to save the data. + + :return: Any relevant metadata. This is up the the data saver, but will likely + include the URI, etc... This is going to be similar to the metadata returned + by the data loader in the loading tuple. + """ + pass + + @classmethod + def can_save(cls) -> bool: + return True diff --git a/hamilton/io/default_data_loaders.py b/hamilton/io/default_data_loaders.py index 1313c1e26..0c88ee71e 100644 --- a/hamilton/io/default_data_loaders.py +++ b/hamilton/io/default_data_loaders.py @@ -4,12 +4,12 @@ import pickle from typing import Any, Collection, Dict, Tuple, Type -from hamilton.io.data_loaders import DataLoader +from hamilton.io.data_adapters import DataLoader, DataSaver from hamilton.io.utils import get_file_loading_metadata @dataclasses.dataclass -class JSONDataLoader(DataLoader): +class JSONDataAdapter(DataLoader, DataSaver): path: str @classmethod @@ -24,25 +24,14 @@ def load_data(self, type_: Type) -> Tuple[dict, Dict[str, Any]]: def name(cls) -> str: return "json" - -@dataclasses.dataclass -class LiteralValueDataLoader(DataLoader): - value: Any - - @classmethod - def load_targets(cls) -> Collection[Type]: - return [Any] - - def load_data(self, type_: Type) -> Tuple[dict, Dict[str, Any]]: - return self.value, {} - - @classmethod - def name(cls) -> str: - return "literal" + def save_data(self, data: Any) -> Dict[str, Any]: + with open(self.path, "w") as f: + json.dump(data, f) + return get_file_loading_metadata(self.path) @dataclasses.dataclass -class RawFileDataLoader(DataLoader): +class RawFileDataLoader(DataLoader, DataSaver): path: str encoding: str = "utf-8" @@ -58,23 +47,33 @@ def load_targets(cls) -> Collection[Type]: def name(cls) -> str: return "file" + def save_data(self, data: Any) -> Dict[str, Any]: + with open(self.path, "w", encoding=self.encoding) as f: + f.write(data) + return get_file_loading_metadata(self.path) + @dataclasses.dataclass class PickleLoader(DataLoader): path: str - def load_data(self, type_: Type[dict]) -> Tuple[str, Dict[str, Any]]: - with open(self.path, "rb") as f: - return pickle.load(f), get_file_loading_metadata(self.path) - @classmethod def load_targets(cls) -> Collection[Type]: - return [str] + return [object] @classmethod def name(cls) -> str: return "pickle" + def load_data(self, type_: Type[dict]) -> Tuple[str, Dict[str, Any]]: + with open(self.path, "rb") as f: + return pickle.load(f), get_file_loading_metadata(self.path) + + def save_data(self, data: Any) -> Dict[str, Any]: + with open(self.path, "wb") as f: + pickle.dump(data, f) + return get_file_loading_metadata(self.path) + @dataclasses.dataclass class EnvVarDataLoader(DataLoader): @@ -92,8 +91,24 @@ def load_targets(cls) -> Collection[Type]: return [dict] -DATA_LOADERS = [ - JSONDataLoader, +@dataclasses.dataclass +class LiteralValueDataLoader(DataLoader): + value: Any + + @classmethod + def load_targets(cls) -> Collection[Type]: + return [Any] + + def load_data(self, type_: Type) -> Tuple[dict, Dict[str, Any]]: + return self.value, {} + + @classmethod + def name(cls) -> str: + return "literal" + + +DATA_ADAPTERS = [ + JSONDataAdapter, LiteralValueDataLoader, RawFileDataLoader, PickleLoader, diff --git a/hamilton/plugins/pandas_extensions.py b/hamilton/plugins/pandas_extensions.py index 03d3c03db..d505cdb1e 100644 --- a/hamilton/plugins/pandas_extensions.py +++ b/hamilton/plugins/pandas_extensions.py @@ -3,7 +3,7 @@ from typing import Any, Collection, Dict, Tuple, Type from hamilton.io import utils -from hamilton.io.data_loaders import DataLoader +from hamilton.io.data_adapters import DataLoader try: import pandas as pd diff --git a/hamilton/registry.py b/hamilton/registry.py index 438b87e22..0c548a5b7 100644 --- a/hamilton/registry.py +++ b/hamilton/registry.py @@ -87,7 +87,8 @@ def load_extension(plugin_module: str): logger.info(f"Detected {plugin_module} and successfully loaded Hamilton extensions.") -ADAPTER_REGISTRY = collections.defaultdict(list) +LOADER_REGISTRY = collections.defaultdict(list) +SAVER_REGISTRY = collections.defaultdict(list) def register_adapter(adapter: Any): @@ -96,4 +97,7 @@ def register_adapter(adapter: Any): :param adapter: the adapter to register. """ - ADAPTER_REGISTRY[adapter.name()].append(adapter) + if adapter.can_load(): + LOADER_REGISTRY[adapter.name()].append(adapter) + if adapter.can_save(): + SAVER_REGISTRY[adapter.name()].append(adapter) diff --git a/tests/function_modifiers/test_adapters.py b/tests/function_modifiers/test_adapters.py index c5ce5879d..ce98d46d8 100644 --- a/tests/function_modifiers/test_adapters.py +++ b/tests/function_modifiers/test_adapters.py @@ -1,27 +1,33 @@ import dataclasses from collections import Counter -from typing import Any, Collection, Dict, Tuple, Type +from typing import Any, Collection, Dict, List, Tuple, Type import pandas as pd import pytest -from hamilton import ad_hoc_utils, base, driver, graph +from hamilton import ad_hoc_utils, base, driver, graph, node from hamilton.function_modifiers import base as fm_base from hamilton.function_modifiers import source, value -from hamilton.function_modifiers.adapters import LoadFromDecorator, load_from -from hamilton.io.data_loaders import DataLoader -from hamilton.registry import ADAPTER_REGISTRY +from hamilton.function_modifiers.adapters import ( + LoadFromDecorator, + SaveToDecorator, + load_from, + resolve_adapter_class, + resolve_kwargs, +) +from hamilton.io.data_adapters import DataLoader, DataSaver +from hamilton.registry import LOADER_REGISTRY def test_default_adapters_are_available(): - assert len(ADAPTER_REGISTRY) > 0 + assert len(LOADER_REGISTRY) > 0 def test_default_adapters_are_registered_once(): - assert "json" in ADAPTER_REGISTRY + assert "json" in LOADER_REGISTRY count_unique = { key: Counter([value.__class__.__qualname__ for value in values]) - for key, values in ADAPTER_REGISTRY.items() + for key, values in LOADER_REGISTRY.items() } for key, value_ in count_unique.items(): for impl, count in value_.items(): @@ -53,8 +59,7 @@ def name(cls) -> str: def test_load_from_decorator_resolve_kwargs(): - decorator = LoadFromDecorator( - [MockDataLoader], + kwargs = dict( required_param=source("1"), required_param_2=value(2), required_param_3=value("3"), @@ -62,7 +67,7 @@ def test_load_from_decorator_resolve_kwargs(): default_param_2=value(5), ) - dependency_kwargs, literal_kwargs = decorator.resolve_kwargs() + dependency_kwargs, literal_kwargs = resolve_kwargs(kwargs) assert dependency_kwargs == {"required_param": "1", "default_param": "4"} assert literal_kwargs == {"required_param_2": 2, "required_param_3": "3", "default_param_2": 5} @@ -172,7 +177,7 @@ def load_targets(cls) -> Collection[Type]: @dataclasses.dataclass -class IntDataLoaderClass2(DataLoader): +class IntDataLoader2(DataLoader): @classmethod def load_targets(cls) -> Collection[Type]: return [int] @@ -237,9 +242,27 @@ def fn_bool_inject(injected_data: bool) -> bool: # Note that this tests an internal API, but we would like to test this to ensure # class selection is correct -def test_load_from_resolves_correct_class(): +@pytest.mark.parametrize( + "type_,classes,correct_class", + [ + (str, [StringDataLoader, IntDataLoader, IntDataLoader2], StringDataLoader), + (int, [StringDataLoader, IntDataLoader, IntDataLoader2], IntDataLoader2), + (int, [IntDataLoader2, IntDataLoader], IntDataLoader), + (int, [IntDataLoader, IntDataLoader2], IntDataLoader2), + (int, [StringDataLoader], None), + (str, [IntDataLoader], None), + (dict, [IntDataLoader], None), + ], +) +def test_resolve_correct_loader_class( + type_: Type[Type], classes: List[Type[DataLoader]], correct_class: Type[DataLoader] +): + assert resolve_adapter_class(type_, classes) == correct_class + + +def test_decorator_validate(): decorator = LoadFromDecorator( - [StringDataLoader, IntDataLoader, IntDataLoaderClass2], + [StringDataLoader, IntDataLoader, IntDataLoader2], ) def fn_str_inject(injected_data: str) -> str: @@ -252,11 +275,11 @@ def fn_bool_inject(injected_data: bool) -> bool: return injected_data # This is valid as there is one parameter and its a type that the decorator supports - assert decorator._resolve_loader_class(fn_str_inject) == StringDataLoader - assert decorator._resolve_loader_class(fn_int_inject) == IntDataLoaderClass2 + decorator.validate(fn_str_inject) + decorator.validate(fn_int_inject) # This is invalid as there is one parameter and it is not a type that the decorator supports with pytest.raises(fm_base.InvalidDecoratorException): - decorator._resolve_loader_class(fn_bool_inject) + decorator.validate(fn_bool_inject) # End-to-end tests are probably cleanest @@ -265,7 +288,7 @@ def fn_bool_inject(injected_data: bool) -> bool: # We don't test the driver, we just use the function_graph to tests the nodes def test_load_from_decorator_end_to_end(): @LoadFromDecorator( - [StringDataLoader, IntDataLoader, IntDataLoaderClass2], + [StringDataLoader, IntDataLoader, IntDataLoader2], ) def fn_str_inject(injected_data: str) -> str: return injected_data @@ -338,3 +361,47 @@ def df(data: pd.DataFrame) -> pd.DataFrame: ) assert result["df"].shape == (3, 5) assert result["df"].loc[0, "firstName"] == "John" + + +@dataclasses.dataclass +class MarkingSaver(DataSaver): + markers: set + more_markers: set + + def save_data(self, data: int) -> Dict[str, Any]: + self.markers.add(data) + self.more_markers.add(data) + return {} + + @classmethod + def load_targets(cls) -> Collection[Type]: + return [int] + + @classmethod + def name(cls) -> str: + return "marker" + + +def test_save_to_decorator(): + def fn() -> int: + return 1 + + marking_set = set() + marking_set_2 = set() + decorator = SaveToDecorator( + [MarkingSaver], + artifact_name_="save_fn", + markers=value(marking_set), + more_markers=source("more_markers"), + ) + node_ = node.Node.from_fn(fn) + nodes = decorator.transform_node(node_, {}, fn) + assert len(nodes) == 2 + nodes_by_name = {node_.name: node_ for node_ in nodes} + assert "save_fn" in nodes_by_name + assert "fn" in nodes_by_name + assert sorted(nodes_by_name["save_fn"].input_types.keys()) == ["fn", "more_markers"] + assert nodes_by_name["save_fn"](**{"fn": 1, "more_markers": marking_set_2}) == {} + # Check that the markers are updated, ensuring that the save_fn is called + assert marking_set_2 == {1} + assert marking_set == {1} diff --git a/tests/materialization/test_data_loaders.py b/tests/materialization/test_data_adapters.py similarity index 92% rename from tests/materialization/test_data_loaders.py rename to tests/materialization/test_data_adapters.py index 32c6d5a3c..7e81a1adc 100644 --- a/tests/materialization/test_data_loaders.py +++ b/tests/materialization/test_data_adapters.py @@ -1,7 +1,7 @@ import dataclasses from typing import Any, Collection, Dict, Tuple, Type -from hamilton.io.data_loaders import DataLoader +from hamilton.io.data_adapters import DataLoader @dataclasses.dataclass