Skip to content

Commit

Permalink
Implements basic data saving decorator as well
Browse files Browse the repository at this point in the history
This decorates the output of a function -- creating a node to append to
the end that saves it. This has similar functionality to data loaders,
and has been grouped togther under the data loading umbrella.
  • Loading branch information
elijahbenizzy committed Apr 4, 2023
1 parent aa8ba68 commit dffb1f3
Show file tree
Hide file tree
Showing 7 changed files with 333 additions and 101 deletions.
265 changes: 208 additions & 57 deletions hamilton/function_modifiers/adapters.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,39 @@
import inspect
import typing
from typing import Any, Callable, Dict, List, Tuple, Type
from typing import Any, Callable, Collection, Dict, List, Tuple, Type

from hamilton import node
from hamilton.function_modifiers.base import InvalidDecoratorException, NodeCreator
from hamilton.function_modifiers.base import (
InvalidDecoratorException,
NodeCreator,
SingleNodeNodeTransformer,
)
from hamilton.function_modifiers.dependencies import (
LiteralDependency,
ParametrizedDependency,
UpstreamDependency,
)
from hamilton.io.data_loaders import DataLoader
from hamilton.io.data_adapters import AdapterCommon, DataLoader, DataSaver
from hamilton.node import DependencyType
from hamilton.registry import ADAPTER_REGISTRY
from hamilton.registry import LOADER_REGISTRY, SAVER_REGISTRY


class LoaderFactory:
class AdapterFactory:
"""Factory for data loaders. This handles the fact that we pass in source(...) and value(...)
parameters to the data loaders."""

def __init__(self, loader_cls: Type[DataLoader], **kwargs: ParametrizedDependency):
"""Initializes a loader factory. This takes in parameterized dependencies
def __init__(self, adapter_cls: Type[AdapterCommon], **kwargs: ParametrizedDependency):
"""Initializes an adapter factory. This takes in parameterized dependencies
and stores them for later resolution.
Note that this is not strictly necessary -- we could easily put this in the
decorator, but I wanted to separate out/consolidate the logic between data savers and data
loaders.
:param loader_cls: Class of the loader to create.
:param adapter_cls: Class of the loader to create.
:param kwargs: Keyword arguments to pass to the loader, as parameterized dependencies.
"""
self.loader_cls = loader_cls
self.adapter_cls = adapter_cls
self.kwargs = kwargs
self.validate()

Expand All @@ -37,25 +43,62 @@ def validate(self):
:raises InvalidDecoratorException: If the arguments are invalid.
"""
required_args = self.loader_cls.get_required_arguments()
optional_args = self.loader_cls.get_optional_arguments()
required_args = self.adapter_cls.get_required_arguments()
optional_args = self.adapter_cls.get_optional_arguments()
missing_params = set(required_args.keys()) - set(self.kwargs.keys())
extra_params = (
set(self.kwargs.keys()) - set(required_args.keys()) - set(optional_args.keys())
)
if len(missing_params) > 0:
raise InvalidDecoratorException(
f"Missing required parameters for loader : {self.loader_cls}: {missing_params}. "
f"Missing required parameters for adapter : {self.adapter_cls}: {missing_params}. "
f"Required parameters/types are: {required_args}. Optional parameters/types are: "
f"{optional_args}. "
)
if len(extra_params) > 0:
raise InvalidDecoratorException(
f"Extra parameters for loader: {self.loader_cls} {extra_params}"
f"Extra parameters for loader: {self.adapter_cls} {extra_params}"
)

def create_loader(self, **resolved_kwargs: Any) -> DataLoader:
return self.loader_cls(**resolved_kwargs)
if not self.adapter_cls.can_load():
raise InvalidDecoratorException(f"Adapter {self.adapter_cls} cannot load data.")
return self.adapter_cls(**resolved_kwargs)

def create_saver(self, **resolved_kwargs: Any) -> DataSaver:
if not self.adapter_cls.can_load():
raise InvalidDecoratorException(f"Adapter {self.adapter_cls} cannot save data.")
return self.adapter_cls(**resolved_kwargs)


def resolve_kwargs(kwargs: Dict[str, Any]) -> Tuple[Dict[str, str], Dict[str, Any]]:
"""Resolves kwargs to a list of dependencies, and a dictionary of name
to resolved literal values.
:return: A tuple of the dependencies, and the resolved literal kwargs.
"""
dependencies = {}
resolved_kwargs = {}
for name, dependency in kwargs.items():
if isinstance(dependency, UpstreamDependency):
dependencies[name] = dependency.source
elif isinstance(dependency, LiteralDependency):
resolved_kwargs[name] = dependency.value
return dependencies, resolved_kwargs


def resolve_adapter_class(
type_: Type[Type], loader_classes: List[Type[AdapterCommon]]
) -> Type[AdapterCommon]:
"""Resolves the loader class for a function. This will return the most recently
registered loader class that applies to the injection type, hence the reversed order.
:param fn: Function to inject the loaded data into.
:return: The loader class to use.
"""
for loader_cls in reversed(loader_classes):
if loader_cls.applies_to(type_):
return loader_cls


class LoadFromDecorator(NodeCreator):
Expand All @@ -76,21 +119,6 @@ def __init__(
self.kwargs = kwargs
self.inject = inject_

def resolve_kwargs(self) -> Tuple[Dict[str, str], Dict[str, Any]]:
"""Resolves kwargs to a list of dependencies, and a dictionary of name
to resolved literal values.
:return: A tuple of the dependencies, and the resolved literal kwargs.
"""
dependencies = {}
resolved_kwargs = {}
for name, dependency in self.kwargs.items():
if isinstance(dependency, UpstreamDependency):
dependencies[name] = dependency.source
elif isinstance(dependency, LiteralDependency):
resolved_kwargs[name] = dependency.value
return dependencies, resolved_kwargs

def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node]:
"""Generates two nodes:
1. A node that loads the data from the data source, and returns that + metadata
Expand All @@ -100,18 +128,27 @@ def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node
:param config: The configuration to use.
:return: The resolved nodes
"""
loader_cls = self._resolve_loader_class(fn)
loader_factory = LoaderFactory(loader_cls, **self.kwargs)
inject_parameter, type_ = self._get_inject_parameter(fn)
loader_cls = resolve_adapter_class(
type_,
self.loader_classes,
)
if loader_cls is None:
raise InvalidDecoratorException(
f"No loader class found for type: {type_} specified by "
f"parameter: {inject_parameter} in function: {fn.__qualname__}"
)
loader_factory = AdapterFactory(loader_cls, **self.kwargs)
# dependencies is a map from param name -> source name
# we use this to pass the right arguments to the loader.
dependencies, resolved_kwargs = self.resolve_kwargs()
dependencies, resolved_kwargs = resolve_kwargs(self.kwargs)
# we need to invert the dependencies so that we can pass
# the right argument to the loader
dependencies_inverted = {v: k for k, v in dependencies.items()}
inject_parameter, load_type = self._get_inject_parameter(fn)

def load_data(
__loader_factory: LoaderFactory = loader_factory,
__loader_factory: AdapterFactory = loader_factory,
__load_type: Type[Type] = load_type,
__resolved_kwargs=resolved_kwargs,
__dependencies=dependencies_inverted,
Expand Down Expand Up @@ -207,42 +244,30 @@ def validate(self, fn: Callable):
:param fn:
:return:
"""
self._get_inject_parameter(fn)
cls = self._resolve_loader_class(fn)
loader_factory = LoaderFactory(cls, **self.kwargs)
inject_parameter, type_ = self._get_inject_parameter(fn)
cls = resolve_adapter_class(type_, self.loader_classes)
if cls is None:
raise InvalidDecoratorException(
f"No loader class found for type: {type_} specified by "
f"parameter: {inject_parameter} in function: {fn.__qualname__}"
)
loader_factory = AdapterFactory(cls, **self.kwargs)
loader_factory.validate()

def _resolve_loader_class(self, fn: Callable) -> Type[DataLoader]:
"""Resolves the loader class for a function. This will return the most recently
registered loader class that applies to the injection type, hence the reversed order.
:param fn: Function to inject the loaded data into.
:return: The loader class to use.
"""
param, type_ = self._get_inject_parameter(fn)
for loader_cls in reversed(self.loader_classes):
if loader_cls.applies_to(type_):
return loader_cls

raise InvalidDecoratorException(
f"No loader class found for type: {type_} specified by "
f"parameter: {param} in function: {fn.__qualname__}"
)


class load_from__meta__(type):
def __getattr__(cls, item: str):
if item in ADAPTER_REGISTRY:
return load_from.decorator_factory(ADAPTER_REGISTRY[item])
if item in LOADER_REGISTRY:
return load_from.decorator_factory(LOADER_REGISTRY[item])
try:
return super().__getattribute__(item)
except AttributeError as e:
raise AttributeError(
f"No loader named: {item} available for {cls.__name__}. "
f"Available loaders are: {ADAPTER_REGISTRY.keys()}. "
f"Available loaders are: {LOADER_REGISTRY.keys()}. "
f"If you've gotten to this point, you either (1) spelled the "
f"loader name wrong, (2) are trying to use a loader that does"
f"not exist (yet), or (3) have implemented it and "
f"not exist (yet)"
) from e


Expand Down Expand Up @@ -329,3 +354,129 @@ def create_decorator(
return LoadFromDecorator(__loaders, inject_=inject_, **kwargs)

return create_decorator


class save_to__meta__(type):
def __getattr__(cls, item: str):
if item in SAVER_REGISTRY:
return save_to.decorator_factory(SAVER_REGISTRY[item])
try:
return super().__getattribute__(item)
except AttributeError as e:
raise AttributeError(
f"No saver named: {item} available for {cls.__name__}. "
f"Available data savers are: {SAVER_REGISTRY.keys()}. "
f"If you've gotten to this point, you either (1) spelled the "
f"loader name wrong, (2) are trying to use a saver that does"
f"not exist (yet)"
) from e


class SaveToDecorator(SingleNodeNodeTransformer):
def __init__(
self,
saver_classes_: typing.Sequence[Type[DataSaver]],
artifact_name_: str = None,
**kwargs: ParametrizedDependency,
):
super().__init__(**kwargs)
self.artifact_name = artifact_name_
self.saver_classes = saver_classes_

def transform_node(
self, node_: node.Node, config: Dict[str, Any], fn: Callable
) -> Collection[node.Node]:
artifact_name = self.artifact_name
artifact_namespace = ()

if artifact_name is None:
artifact_name = node_.name
artifact_namespace = ("save",)

type_ = node_.type
saver_cls = resolve_adapter_class(
type_,
self.saver_classes,
)
if saver_cls is None:
raise InvalidDecoratorException(
f"No saver class found for type: {type_} specified by "
f"output type: {type_} in node: {node_.name} generated by "
f"function: {fn.__qualname__}."
)

adapter_factory = AdapterFactory(saver_cls, **self.kwargs)
dependencies, resolved_kwargs = resolve_kwargs(self.kwargs)
dependencies_inverted = {v: k for k, v in dependencies.items()}

def save_data(
__adapter_factory=adapter_factory,
__dependencies=dependencies_inverted,
__resolved_kwargs=resolved_kwargs,
__data_node_name=node_.name,
**input_kwargs,
) -> Dict[str, Any]:
input_args_with_fixed_dependencies = {
__dependencies.get(key, key): value for key, value in input_kwargs.items()
}
kwargs = {**__resolved_kwargs, **input_args_with_fixed_dependencies}
data_saver = __adapter_factory.create_saver(**kwargs)
data_to_save = kwargs[__data_node_name]
return data_saver.save_data(data_to_save)

def get_input_type_key(key: str) -> str:
return key if key not in dependencies else dependencies[key]

input_types = {
get_input_type_key(key): (Any, DependencyType.REQUIRED)
for key in saver_cls.get_required_arguments()
}
input_types.update(
{
(get_input_type_key(key) if key not in dependencies else dependencies[key]): (
Any,
DependencyType.OPTIONAL,
)
for key in saver_cls.get_optional_arguments()
}
)
# Take out all the resolved kwargs, as they are not dependencies, and will be filled out
# later
input_types = {
key: value for key, value in input_types.items() if key not in resolved_kwargs
}
input_types[node_.name] = (node_.type, DependencyType.REQUIRED)

save_node = node.Node(
name=artifact_name,
callabl=save_data,
typ=Dict[str, Any],
input_types=input_types,
namespace=artifact_namespace,
)
return [save_node, node_]

def validate(self, fn: Callable):
pass


class save_to(metaclass=load_from__meta__):
def __call__(self, *args, **kwargs):
return LoadFromDecorator(*args, **kwargs)

@classmethod
def decorator_factory(
cls, savers: typing.Sequence[Type[DataSaver]]
) -> Callable[..., SaveToDecorator]:
"""Effectively a partial function for the load_from decorator. Broken into its own (
rather than using functools.partial) as it is a little clearer to parse.
:param savers: Candidate data savers
:param loaders: Options of data loader classes to use
:return: The data loader decorator.
"""

def create_decorator(__savers=tuple(savers), **kwargs: ParametrizedDependency):
return SaveToDecorator(__savers, **kwargs)

return create_decorator
Loading

0 comments on commit dffb1f3

Please sign in to comment.