diff --git a/hamilton/function_modifiers/macros.py b/hamilton/function_modifiers/macros.py index ec89bd74b..d855448cf 100644 --- a/hamilton/function_modifiers/macros.py +++ b/hamilton/function_modifiers/macros.py @@ -135,7 +135,7 @@ def test_function_signatures_compatible( dummy_param_values = { key: SENTINEL_ARG_VALUE for key, param_spec in fn_signature.parameters.items() - if param_spec.default != inspect.Parameter.empty + if param_spec.default is not inspect.Parameter.empty } # Then we update with the dummy values. Again, replacing doesn't matter (we'll be mimicking it later) dummy_param_values.update({key: SENTINEL_ARG_VALUE for key in fn_signature.parameters}) @@ -214,7 +214,7 @@ def wrapper_function(**kwargs): final_kwarg_values = { key: param_spec.default for key, param_spec in inspect.signature(fn).parameters.items() - if param_spec.default != inspect.Parameter.empty + if param_spec.default is not inspect.Parameter.empty } final_kwarg_values.update(kwargs) final_kwarg_values = does.map_kwargs(final_kwarg_values, self.argument_mapping) diff --git a/hamilton/graph.py b/hamilton/graph.py index 16f1c6534..c6a18e5e2 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -187,7 +187,7 @@ def _check_keyword_args_only(func: Callable) -> bool: """Checks if a function only takes keyword arguments.""" sig = inspect.signature(func) for param in sig.parameters.values(): - if param.default == inspect.Parameter.empty and param.kind not in [ + if param.default is inspect.Parameter.empty and param.kind not in [ inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.VAR_KEYWORD, ]: diff --git a/hamilton/node.py b/hamilton/node.py index e672f0499..4d4a4a764 100644 --- a/hamilton/node.py +++ b/hamilton/node.py @@ -36,7 +36,7 @@ class DependencyType(Enum): @staticmethod def from_parameter(param: inspect.Parameter): - if param.default == inspect.Parameter.empty: + if param.default is inspect.Parameter.empty: return DependencyType.REQUIRED return DependencyType.OPTIONAL diff --git a/tests/test_node.py b/tests/test_node.py index 87cc255b5..c2da8cde1 100644 --- a/tests/test_node.py +++ b/tests/test_node.py @@ -1,3 +1,4 @@ +import inspect import sys from typing import Any, Literal, TypeVar @@ -104,3 +105,15 @@ def annotated_func(first: ArrayN[np.float64], other: float = 2.0) -> ArrayN[np.f ) def test_tags_match_query(tags: dict, query: dict, expected: bool): assert matches_query(tags, query) == expected + + +def test_from_parameter_default_override_equals(): + class BrokenEquals: + def __eq__(self, other): + raise ValueError("I'm broken") + + def foo(b: BrokenEquals = BrokenEquals()): + pass + + param = DependencyType.from_parameter(inspect.signature(foo).parameters["b"]) + assert param == DependencyType.OPTIONAL