Skip to content
This repository has been archived by the owner on Jul 3, 2023. It is now read-only.

Commit

Permalink
Replaces @parameterized variants with delegation to @parametrized_full
Browse files Browse the repository at this point in the history
We no longer need bespoke logic in each variant. Thus we remove the bodies
of all variants and have them delegate to @parametrized_full instead. We can
probably remove a little validation as well...
  • Loading branch information
elijahbenizzy committed Jul 30, 2022
1 parent 9fd52dd commit 6b04f8a
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 170 deletions.
220 changes: 52 additions & 168 deletions hamilton/function_modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def get_dependency_type(self) -> ParametrizedDependencySource:

# def is_literal(self) -> bool:

def literal(val: Any) -> LiteralDependency:
return LiteralDependency(value=val)
def literal(value: Any) -> LiteralDependency:
return LiteralDependency(value=value)


def upstream(source: Any) -> UpstreamDependency:
Expand All @@ -98,31 +98,31 @@ def concat(upstream_parameter: str, literal_parameter: str) -> Any:
- a tuple of assignments (consisting of literals/upstream specifications), and docstring
- just assignments, in which case it parametrizes the existing docstring
"""
self.parametrization = parametrization
self.parametrization = {key: (value[0] if isinstance(value, tuple) else value) for key, value in parametrization.items()}
bad_values = []
for assigned_output, mapping in self.parametrization.items():
for parameter, value in mapping.items():
if not isinstance(value, ParametrizedDependency):
bad_values.append(value)
if bad_values:
raise InvalidDecoratorException(f'@parametrized_full must specify a dependency type -- either upstream() or literal().'
f'The following are not allowed: {bad_values}.')
self.specified_docstrings = {key: value[1] for key, value in parametrization.items() if isinstance(value, tuple)}

def expand_node(self, node_: node.Node, config: Dict[str, Any], fn: Callable) -> Collection[node.Node]:
nodes = []
for output_node, parametrization_with_optional_docstring in self.parametrization.items():
if isinstance(parametrization_with_optional_docstring, tuple): # In this case it contains the docstring
parametrization, docstring = parametrization_with_optional_docstring
parametrization, = parametrization_with_optional_docstring
else:
parametrization = parametrization_with_optional_docstring
docstring = None
docstring = self.format_doc_string(fn.__doc__, output_node)
upstream_dependencies = {
parameter: replacement for parameter, replacement in parametrization.items()
if replacement.get_dependency_type() == ParametrizedDependencySource.UPSTREAM}
literal_dependencies = {
parameter: replacement for parameter, replacement in parametrization.items()
if replacement.get_dependency_type() == ParametrizedDependencySource.LITERAL}
if docstring is None:
# then we have to generate
docstring = self.format_doc_string(
fn.__doc__,
output_node,
**{
**{key: value.source for key, value in upstream_dependencies.items()},
**{key: value.value for key, value in literal_dependencies.items()}
})

# Should we have a `literal` node/dependency type? Might be nicer than this -- E.G. we can see the inputs...
# Or we can create actual nodes that are just literals
Expand All @@ -139,15 +139,15 @@ def replacement_function(*args, upstream_dependencies=upstream_dependencies, lit
new_input_types = {}
for param, value in node_.input_types.items():
if param in upstream_dependencies:
new_input_types[upstream_dependencies[param].source] = value # We replace with the upstream_dependencies
new_input_types[upstream_dependencies[param].source] = value # We replace with the upstream_dependencies
elif param not in literal_dependencies:
new_input_types[param] = value # We just use the standard one, nothing is getting replaced
new_input_types[param] = value # We just use the standard one, nothing is getting replaced

nodes.append(
node.Node(
name=output_node,
typ=node_.type,
doc_string=docstring, # TODO -- change docstring
doc_string=docstring, # TODO -- change docstring
callabl=functools.partial(
replacement_function,
**{parameter: value.value for parameter, value in literal_dependencies.items()}),
Expand All @@ -159,9 +159,10 @@ def validate(self, fn: Callable):
signature = inspect.signature(fn)
func_param_names = set(signature.parameters.keys())
try:
# print(self.parametrization)
for output_name, mappings in self.parametrization.items():
# TODO -- separate out into the two dependency-types
self.format_doc_string(fn.__doc__, output_name, **mappings)
self.format_doc_string(fn.__doc__, output_name)
except KeyError as e:
raise InvalidDecoratorException(f'Function docstring templating is incorrect. '
f'Please fix up the docstring {fn.__module__}.{fn.__name__}.') from e
Expand All @@ -171,14 +172,14 @@ def validate(self, fn: Callable):
f'Error function {fn.__module__}.{fn.__name__} cannot have `{self.RESERVED_KWARG}` '
f'as a parameter it is reserved.')
missing_parameters = set()
for (parametrization, _) in self.parametrization.values():
for param_to_replace in parametrization.keys():
for mapping in self.parametrization.values():
for param_to_replace in mapping :
if param_to_replace not in func_param_names:
missing_parameters.add(param_to_replace)
if missing_parameters:
raise ValueError(f"Parametrization is invalid: the following parameters don't appear in the function itself: {', '.join(missing_parameters)}")
raise InvalidDecoratorException(f"Parametrization is invalid: the following parameters don't appear in the function itself: {', '.join(missing_parameters)}")

def format_doc_string(self, doc: str, output_name: str, **params: Dict[str, str]) -> str:
def format_doc_string(self, doc: str, output_name: str) -> str:
"""Helper function to format a function documentation string.
:param doc: the string template to format
Expand All @@ -187,58 +188,45 @@ def format_doc_string(self, doc: str, output_name: str, **params: Dict[str, str]
:return: formatted string
:raises: KeyError if there is a template variable missing from the parameter mapping.
"""

class IdentityDict(dict):
# quick hack to allow for formatting of missing parameters
def __missing__(self, key):
return key
return doc.format_map(IdentityDict(**{self.RESERVED_KWARG: output_name}, **params))


class parametrized(function_modifiers_base.NodeExpander):
if output_name in self.specified_docstrings:
return self.specified_docstrings[output_name]
if doc is None:
return None
parametrization = self.parametrization[output_name]
upstream_dependencies = {
parameter: replacement.source for parameter, replacement in parametrization.items()
if replacement.get_dependency_type() == ParametrizedDependencySource.UPSTREAM}
literal_dependencies = {
parameter: replacement.value for parameter, replacement in parametrization.items()
if replacement.get_dependency_type() == ParametrizedDependencySource.LITERAL}
return doc.format_map(
IdentityDict(
**{self.RESERVED_KWARG: output_name},
**{**upstream_dependencies, **literal_dependencies}))


class parametrized(parametrized_full):
def __init__(self, parameter: str, assigned_output: Dict[Tuple[str, str], Any]):
"""Constructor for a modifier that expands a single function into n, each of which
corresponds to a function in which the parameter value is replaced by that *specific value*.
:param parameter: Parameter to expand on.
:param assigned_output: A map of tuple of [parameter names, documentation] to values
"""
self.parameter = parameter
self.assigned_output = assigned_output
for node in assigned_output.keys():
if not isinstance(node, Tuple):
for node_ in assigned_output.keys():
if not isinstance(node_, Tuple):
raise InvalidDecoratorException(
f'assigned_output key is incorrect: {node}. The parameterized decorator needs a dict of '
f'assigned_output key is incorrect: {node_}. The parameterized decorator needs a dict of '
'[name, doc string] -> value to function.')

def validate(self, fn: Callable):
"""A function is invalid if it does not have the requested parameter.
:param fn: Function to validate against this annotation.
:raises: InvalidDecoratorException If the function does not have the requested parameter
"""
signature = inspect.signature(fn)
if self.parameter not in signature.parameters.keys():
raise InvalidDecoratorException(
f'Annotation is invalid -- no such parameter {self.parameter} in function {fn}')

def expand_node(self, node_: node.Node, config: Dict[str, Any], fn: Callable) -> Collection[node.Node]:
"""For each parameter value, loop through, partially curry the function, and output a node."""
input_types = node_.input_types
nodes = []
for (node_name, node_doc), value in self.assigned_output.items():
nodes.append(
node.Node(
node_name,
node_.type,
node_doc,
functools.partial(node_.callable, **{self.parameter: value}),
input_types={key: value for key, (value, _) in input_types.items() if key != self.parameter},
tags=node_.tags.copy()))
return nodes
super(parametrized, self).__init__(**{output: ({parameter: literal(value)}, documentation) for (output, documentation), value in assigned_output.items()})


class parametrized_input(function_modifiers_base.NodeExpander):

class parametrized_input(parametrized_full):
def __init__(self, parameter: str, variable_inputs: Dict[str, Tuple[str, str]]):
"""Constructor for a modifier that expands a single function into n, each of which
corresponds to the specified parameter replaced by a *specific input column*.
Expand All @@ -257,52 +245,16 @@ def __init__(self, parameter: str, variable_inputs: Dict[str, Tuple[str, str]]):
"""
logger.warning('`parameterized_input` (singular) is deprecated. It will be removed in a 2.0.0 release. '
'Please migrate to using `parameterized_inputs` (plural).')
self.parameter = parameter
self.assigned_output = variable_inputs
for value in variable_inputs.values():
if not isinstance(value, Tuple):
raise InvalidDecoratorException(
f'assigned_output key is incorrect: {node}. The parameterized decorator needs a dict of '
'input column -> [name, description] to function.')

def expand_node(self, node_: node.Node, config: Dict[str, Any], fn: Callable) -> Collection[node.Node]:
nodes = []
input_types = node_.input_types
for input_column, (node_name, node_description) in self.assigned_output.items():
specific_inputs = {key: value for key, (value, _) in input_types.items()}
specific_inputs[input_column] = specific_inputs.pop(self.parameter) # replace the name with the new function name so we get the right dependencies

def new_fn(*args, input_column=input_column, **kwargs):
"""This function rewrites what is passed in kwargs to the right kwarg for the function."""
kwargs = kwargs.copy()
kwargs[self.parameter] = kwargs.pop(input_column)
return node_.callable(*args, **kwargs)

nodes.append(
node.Node(
node_name,
node_.type,
node_description,
new_fn,
input_types=specific_inputs,
tags=node_.tags.copy()))
return nodes

def validate(self, fn: Callable):
"""A function is invalid if it does not have the requested parameter.
:param fn: Function to validate against this annotation.
:raises: InvalidDecoratorException If the function does not have the requested parameter
"""
signature = inspect.signature(fn)
if self.parameter not in signature.parameters.keys():
raise InvalidDecoratorException(
f'Annotation is invalid -- no such parameter {self.parameter} in function {fn}')
super(parametrized_input, self).__init__(
**{output: ({parameter: upstream(value)}, documentation) for value, (output, documentation) in variable_inputs.items()})


class parameterized_inputs(function_modifiers_base.NodeExpander):
RESERVED_KWARG = 'output_name'

class parameterized_inputs(parametrized_full):
def __init__(self, **parameterization: Dict[str, Dict[str, str]]):
"""Constructor for a modifier that expands a single function into n, each of which corresponds to replacing
some subset of the specified parameters with specific inputs.
Expand All @@ -327,76 +279,8 @@ def __init__(self, **parameterization: Dict[str, Dict[str, str]]):
for output, mappings in parameterization.items():
if not mappings:
raise ValueError(f'Error, {output} has a none/empty dictionary mapping. Please fill it.')

def expand_node(self, node_: node.Node, config: Dict[str, Any], fn: Callable) -> Collection[node.Node]:
nodes = []
input_types = node_.input_types
for output_name, mapping in self.parametrization.items():
node_name = output_name
# output_name is a reserved kwarg name.
node_description = self.format_doc_string(node_.documentation, output_name, **mapping)
specific_inputs = {key: value for key, (value, _) in input_types.items()}
for func_param, replacement_param in mapping.items():
logger.info(f'For function {node_.name}: mapping {replacement_param} to {func_param}.')
# replace the name with the new function name so we get the right dependencies
specific_inputs[replacement_param] = specific_inputs.pop(func_param)

def new_fn(*args, inputs=mapping, **kwargs):
"""This function rewrites what is passed in kwargs to the right kwarg for the function."""
kwargs = kwargs.copy()
for func_param, replacement_param in inputs.items():
kwargs[func_param] = kwargs.pop(replacement_param)
return node_.callable(*args, **kwargs)

nodes.append(
node.Node(
node_name,
node_.type,
node_description,
new_fn,
input_types=specific_inputs,
tags=node_.tags.copy()))
return nodes

def format_doc_string(self, doc: str, output_name: str, **params: Dict[str, str]) -> str:
"""Helper function to format a function documentation string.
:param doc: the string template to format
:param output_name: the output name of the function
:param params: the parameter mappings
:return: formatted string
:raises: KeyError if there is a template variable missing from the parameter mapping.
"""
return doc.format(**{self.RESERVED_KWARG: output_name}, **params)

def validate(self, fn: Callable):
"""A function is invalid if it does not have the requested parameter.
:param fn: Function to validate against this annotation.
:raises: InvalidDecoratorException If the function does not have the requested parameter
"""
signature = inspect.signature(fn)
func_param_name_set = set(signature.parameters.keys())
try:
for output_name, mappings in self.parametrization.items():
self.format_doc_string(fn.__doc__, output_name, **mappings)
except KeyError as e:
raise InvalidDecoratorException(f'Function docstring templating is incorrect. '
f'Please fix up the docstring {fn.__module__}.{fn.__name__}.') from e

if self.RESERVED_KWARG in func_param_name_set:
raise InvalidDecoratorException(
f'Error function {fn.__module__}.{fn.__name__} cannot have `{self.RESERVED_KWARG}` '
f'as a parameter it is reserved.')
missing_params = set()
for output_name, mappings in self.parametrization.items():
for func_name, replacement_name in mappings.items():
if func_name not in func_param_name_set:
missing_params.add(func_name)
if missing_params:
raise InvalidDecoratorException(
f'Annotation is invalid -- No such parameter(s) {missing_params} in function '
f'{fn.__module__}.{fn.__name__}.')
super(parameterized_inputs, self).__init__(
**{output: {parameter: upstream(source) for parameter, source in mapping.items()} for output, mapping in parameterization.items()})


class extract_columns(function_modifiers_base.NodeExpander):
Expand Down
5 changes: 3 additions & 2 deletions tests/test_function_modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,8 @@ def test_parametrized_full_multiple_replacements():
f'1. Add tests to ensure this works ✅ \n'
f'2. Get it to work without supplying docstrings and instead using parametrization ✅ \n'
f'2.5 Figure out what the API should look like if parameters are not supplied -- likely keep it so it can work without that ✅ \n'
f'3. Replace the other variants with this, ensure all use-cases are covered \n'
f'3. Replace the other variants with this, ensure all use-cases are covered \n'
f'4. Refactor to fix weird polymorphism or simplify in another way\n'
f'5. Refactor all decorators to be in a `function_modifiers` module if possible'
f'5. Refactor all decorators to be in a `function_modifiers` module if possible. Then have the __init__.py just copy them\n'
f'6. Add a @deprecate meta-decorator for deprecating decorators\n'
)

0 comments on commit 6b04f8a

Please sign in to comment.