Skip to content

Commit

Permalink
Fixes issue with @Inject and docstring
Browse files Browse the repository at this point in the history
Beforehand, we broke if there was a docstring in an inject decorated
function. This was due to having a placeholder function name, as it
delegated to `@parameterize`. The best fix would be to not have it
delegate (this is a bit of a hack), but for now we make it not break.
  • Loading branch information
elijahbenizzy committed Apr 23, 2023
1 parent f0ea42a commit 5f52ed0
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
1 change: 0 additions & 1 deletion hamilton/function_modifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,6 @@ def filter_config(config: Dict[str, Any], decorator: NodeTransformLifecycle) ->
"""
config_required = decorator.required_config()
config_optional_with_defaults = decorator.optional_config()
print(decorator, config_required, config_optional_with_defaults)
return resolve_config(decorator.name, config, config_required, config_optional_with_defaults)


Expand Down
4 changes: 3 additions & 1 deletion hamilton/function_modifiers/expanders.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,9 @@ def validate(self, fn: Callable):
try:
for output_name, mappings in self.parameterization.items():
# TODO -- separate out into the two dependency-types
self.format_doc_string(fn.__doc__, output_name)
if output_name == self.PLACEHOLDER_PARAM_NAME:
output_name = fn.__name__
self.format_doc_string(fn, output_name)
except KeyError as e:
raise base.InvalidDecoratorException(
f"Function docstring templating is incorrect. "
Expand Down
22 changes: 22 additions & 0 deletions tests/function_modifiers/test_expanders.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,28 @@ def foo(x: Dict) -> int:
annotation.validate(foo)


def test_inject_validate_with_docstring():
def foo(x: int) -> int:
"""Docstring..."""
return x

annotation = function_modifiers.inject(x=value(1))
annotation.validate(foo)
(node_,) = annotation.expand_node(node.Node.from_fn(foo), {}, foo)
assert node_.documentation == "Docstring..."


def test_inject_validate_with_docstring_replacement():
def foo(x: int) -> int:
"""Docstring. x={x} is injected."""
return x

annotation = function_modifiers.inject(x=value(1))
annotation.validate(foo)
(node_,) = annotation.expand_node(node.Node.from_fn(foo), {}, foo)
assert node_.documentation == "Docstring. x=1 is injected."


def test_parameterize_repeated_sources():
def foo(x: int, y: int) -> int:
return x + y
Expand Down

0 comments on commit 5f52ed0

Please sign in to comment.