Skip to content
This repository has been archived by the owner on Jul 3, 2023. It is now read-only.

Commit

Permalink
Fixes minor bug with default arguments
Browse files Browse the repository at this point in the history
There was a bug where we would consider nodes that
were not part of the execution graph to see if all the inputs
to a node were acceptable.

E.g. if two functions required the same input, but one had default arguments,
and the function without a default was not on the execution path, we'd error.
Why? Because we'd naively look at all the node's dependents to see if one was
required, rather than restricting to the set of nodes required for execution.

Adds to unit test; aside our driver doesn't have the greatest unit test coverage :/
  • Loading branch information
skrawcz committed Jan 23, 2023
1 parent 1ecc1e2 commit 1fdf6ec
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 8 deletions.
21 changes: 15 additions & 6 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dataclasses import dataclass, field
from datetime import datetime
from types import ModuleType
from typing import Any, Callable, Collection, Dict, List, Optional, Tuple
from typing import Any, Callable, Collection, Dict, List, Optional, Set, Tuple

import pandas as pd

Expand Down Expand Up @@ -143,39 +143,48 @@ def capture_constructor_telemetry(
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Error caught in processing telemetry: {e}")

def _node_is_required_by_anything(self, node_: node.Node) -> bool:
def _node_is_required_by_anything(self, node_: node.Node, node_set: Set[node.Node]) -> bool:
"""Checks dependencies on this node and determines if at least one requires it.
Nodes can be optionally depended upon, i.e. the function parameter has a default value. We want to check that
of the nodes the depend on this one, at least one of them requires it, i.e. the parameter is not optional.
:param node_: node in question
:param node_set: checks that we traverse only nodes in the provided set.
:return: True if it is required by any downstream node, false otherwise
"""
required = False
for downstream_node in node_.depended_on_by:
if downstream_node not in node_set:
continue
_, dep_type = downstream_node.input_types[node_.name]
if dep_type == node.DependencyType.REQUIRED:
return True
return required

def validate_inputs(
self, user_nodes: Collection[node.Node], inputs: typing.Optional[Dict[str, Any]] = None
self,
user_nodes: Collection[node.Node],
inputs: typing.Optional[Dict[str, Any]] = None,
nodes_set: Collection[node.Node] = None,
):
"""Validates that inputs meet our expectations. This means that:
1. The runtime inputs don't clash with the graph's config
2. All expected graph inputs are provided, either in config or at runtime
:param user_nodes: The required nodes we need for computation.
:param inputs: the user inputs provided.
:param nodes_set: the set of nodes to use for validation; Optional.
"""
if inputs is None:
inputs = {}
if nodes_set is None:
nodes_set = set(self.graph.nodes.values())
(all_inputs,) = (graph.FunctionGraph.combine_config_and_inputs(self.graph.config, inputs),)
errors = []
for user_node in user_nodes:
if user_node.name not in all_inputs:
if self._node_is_required_by_anything(user_node):
if self._node_is_required_by_anything(user_node, nodes_set):
errors.append(
f"Error: Required input {user_node.name} not provided "
f"for nodes: {[node.name for node in user_node.depended_on_by]}."
Expand Down Expand Up @@ -291,7 +300,7 @@ def raw_execute(
"""
nodes, user_nodes = self.graph.get_upstream_nodes(final_vars, inputs)
self.validate_inputs(
user_nodes, inputs
user_nodes, inputs, nodes
) # TODO -- validate within the function graph itself
if display_graph: # deprecated flow.
logger.warning(
Expand Down Expand Up @@ -362,7 +371,7 @@ def visualize_execution(
See https://graphviz.org/doc/info/attrs.html for options.
"""
nodes, user_nodes = self.graph.get_upstream_nodes(final_vars, inputs)
self.validate_inputs(user_nodes, inputs)
self.validate_inputs(user_nodes, inputs, nodes)
try:
self.graph.display(
nodes,
Expand Down
5 changes: 5 additions & 0 deletions tests/resources/test_default_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,8 @@ def B(A: int) -> int:

def C(A: int) -> int: # empty string doc on purpose.
return A * 2


def D(defaults_to_zero: int) -> int:
"""This requires the default value."""
return defaults_to_zero * 2
10 changes: 8 additions & 2 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,14 +436,20 @@ def test_end_to_end_with_config_modifier():
def test_non_required_nodes():
fg = graph.FunctionGraph(tests.resources.test_default_args, config={"required": 10})
results = fg.execute(
[n for n in fg.get_nodes() if n.node_source == NodeSource.STANDARD], {}, {}
# D is not on the execution path, so it should not be break things
[n for n in fg.get_nodes() if n.node_source == NodeSource.STANDARD and n.name != "D"],
{},
{},
)
assert results["A"] == 10
fg = graph.FunctionGraph(
tests.resources.test_default_args, config={"required": 10, "defaults_to_zero": 1}
)
results = fg.execute(
[n for n in fg.get_nodes() if n.node_source == NodeSource.STANDARD], {}, {}
# D is not on the execution path, so it should not be break things
[n for n in fg.get_nodes() if n.node_source == NodeSource.STANDARD and n.name != "D"],
{},
{},
)
assert results["A"] == 11

Expand Down

0 comments on commit 1fdf6ec

Please sign in to comment.