diff --git a/hamilton/function_modifiers/base.py b/hamilton/function_modifiers/base.py index 3eefabe9d..8f36df92b 100644 --- a/hamilton/function_modifiers/base.py +++ b/hamilton/function_modifiers/base.py @@ -605,7 +605,6 @@ def filter_config(config: Dict[str, Any], decorator: NodeTransformLifecycle) -> """ config_required = decorator.required_config() config_optional_with_defaults = decorator.optional_config() - print(decorator, config_required, config_optional_with_defaults) return resolve_config(decorator.name, config, config_required, config_optional_with_defaults) diff --git a/hamilton/function_modifiers/expanders.py b/hamilton/function_modifiers/expanders.py index 22bb0087f..04d81946f 100644 --- a/hamilton/function_modifiers/expanders.py +++ b/hamilton/function_modifiers/expanders.py @@ -276,7 +276,9 @@ def validate(self, fn: Callable): try: for output_name, mappings in self.parameterization.items(): # TODO -- separate out into the two dependency-types - self.format_doc_string(fn.__doc__, output_name) + if output_name == self.PLACEHOLDER_PARAM_NAME: + output_name = fn.__name__ + self.format_doc_string(fn, output_name) except KeyError as e: raise base.InvalidDecoratorException( f"Function docstring templating is incorrect. " diff --git a/tests/function_modifiers/test_expanders.py b/tests/function_modifiers/test_expanders.py index ea888519d..1f80652c5 100644 --- a/tests/function_modifiers/test_expanders.py +++ b/tests/function_modifiers/test_expanders.py @@ -706,6 +706,28 @@ def foo(x: Dict) -> int: annotation.validate(foo) +def test_inject_validate_with_docstring(): + def foo(x: int) -> int: + """Docstring...""" + return x + + annotation = function_modifiers.inject(x=value(1)) + annotation.validate(foo) + (node_,) = annotation.expand_node(node.Node.from_fn(foo), {}, foo) + assert node_.documentation == "Docstring..." + + +def test_inject_validate_with_docstring_replacement(): + def foo(x: int) -> int: + """Docstring. x={x} is injected.""" + return x + + annotation = function_modifiers.inject(x=value(1)) + annotation.validate(foo) + (node_,) = annotation.expand_node(node.Node.from_fn(foo), {}, foo) + assert node_.documentation == "Docstring. x=1 is injected." + + def test_parameterize_repeated_sources(): def foo(x: int, y: int) -> int: return x + y