Skip to content

Commit

Permalink
Enables inputs to be also outputs; Fixes #936 (#947)
Browse files Browse the repository at this point in the history
* Enables inputs to be also outputs; Fixes #936

This enables one to pass in inputs and request them as outputs
independent of the graph.

The use case is joining extra data that is not part of the DAG
(e.g. additional pandas data) onto the results at the end of execution.

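Adds test cases covering the default adapter, the pandas DataFrame result builder, dynamic execution, outputs missing from both the graph and the inputs, and `inputs=None`.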


* Adds two extra checks

* PR review feedback
  • Loading branch information
skrawcz authored Jun 12, 2024
1 parent 5a15d7c commit 4771dd0
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 4 deletions.
7 changes: 5 additions & 2 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,13 @@ def execute(
"""Basic executor for a function graph. Does no task-based execution, just does a DFS
and executes the graph in order, in memory."""
memoized_computation = dict() # memoized storage
nodes = [fg.nodes[node_name] for node_name in final_vars]
nodes = [fg.nodes[node_name] for node_name in final_vars if node_name in fg.nodes]
fg.execute(nodes, memoized_computation, overrides, inputs, run_id=run_id)
outputs = {
final_var: memoized_computation[final_var] for final_var in final_vars
# we do this here to enable inputs to also be used as outputs
# putting inputs into memoized before execution doesn't work due to some graphadapter assumptions.
final_var: memoized_computation.get(final_var, inputs.get(final_var))
for final_var in final_vars
} # only want request variables in df.
del memoized_computation # trying to cleanup some memory
return outputs
Expand Down
14 changes: 12 additions & 2 deletions hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,9 @@ def next_nodes_function(n: node.Node) -> List[node.Node]:
deps.append(dep)
return deps

return self.directional_dfs_traverse(next_nodes_function, starting_nodes=final_vars)
return self.directional_dfs_traverse(
next_nodes_function, starting_nodes=final_vars, runtime_inputs=runtime_inputs
)

def nodes_between(self, start: str, end: str) -> Set[node.Node]:
"""Given our function graph, and a list of desired output variables, returns the subgraph
Expand All @@ -1006,15 +1008,19 @@ def directional_dfs_traverse(
self,
next_nodes_fn: Callable[[node.Node], Collection[node.Node]],
starting_nodes: List[str],
runtime_inputs: Dict[str, Any] = None,
):
"""Traverses the DAG directionally using a DFS.
:param next_nodes_fn: Function to give the next set of nodes
:param starting_nodes: Which nodes to start at.
:param runtime_inputs: runtime inputs to the DAG. This is here to allow for inputs to be also outputs.
:return: a tuple of sets:
- set of all nodes.
- subset of nodes that human input is required for.
"""
if runtime_inputs is None:
runtime_inputs = {}
nodes = set()
user_nodes = set()

Expand All @@ -1029,7 +1035,11 @@ def dfs_traverse(node: node.Node):
missing_vars = []
for var in starting_nodes:
if var not in self.nodes and var not in self.config:
missing_vars.append(var)
# checking for runtime_inputs because it's not in the graph isn't really a graph concern. So perhaps we
# should move this outside of the graph in the future. This will do fine for now.
if var not in runtime_inputs:
# if it's not in the runtime inputs, it's a properly missing variable
missing_vars.append(var)
continue # collect all missing final variables
dfs_traverse(self.nodes[var])
if missing_vars:
Expand Down
96 changes: 96 additions & 0 deletions tests/test_end_to_end.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import importlib
import json
import math
import sys
from typing import Any, Callable, Dict, List, Type

import pandas as pd
import pytest

from hamilton import ad_hoc_utils, base, driver, settings
Expand All @@ -14,6 +16,7 @@

import tests.resources.data_quality
import tests.resources.dynamic_config
import tests.resources.example_module
import tests.resources.overrides
import tests.resources.test_for_materialization

Expand Down Expand Up @@ -448,3 +451,96 @@ def test_driver_validate_with_overrides_2():
.build()
)
assert dr.execute(["d"], overrides={"b": 1})["d"] == 3


def test_driver_extra_inputs_can_be_outputs():
    """Checks that a requested output may come from the runtime inputs instead of the graph.

    Exercises plain execution with the default adapter, both validate entry points,
    and the pandas DataFrame result builder.
    """
    default_driver = (
        driver.Builder()
        .with_modules(tests.resources.overrides)
        .with_adapter(base.DefaultAdapter())
        .build()
    )
    result = default_driver.execute(["d", "e"], inputs={"a": 1, "e": 10})
    assert result["d"] == 4
    assert result["e"] == 10
    # The validation entry points must accept the same request without raising.
    default_driver.validate_execution(["d", "e"], inputs={"a": 1, "e": 10})
    default_driver.validate_materialization(
        additional_vars=["d", "e"], inputs={"a": 1, "e": 10}
    )

    # DataFrame use case: the pass-through input "e" becomes a column of the result.
    frame_driver = (
        driver.Builder()
        .with_modules(tests.resources.example_module)
        .with_adapter(base.PandasDataFrameResult())
        .build()
    )
    frame_inputs = {
        "spend": pd.Series([1, 1, 1, 1, 1]),
        "e": pd.Series([10, 10, 10, 10, 10]),
    }
    frame_result = frame_driver.execute(["avg_3wk_spend", "e"], inputs=frame_inputs)
    expected_frame = pd.DataFrame(
        {
            "avg_3wk_spend": pd.Series([math.nan, math.nan, 1.0, 1.0, 1.0], dtype=float),
            "e": [10, 10, 10, 10, 10],
        }
    )
    pd.testing.assert_frame_equal(frame_result, expected_frame)


def test_driver_v2_extra_inputs_can_be_outputs():
    """Checks, with dynamic execution enabled, that a runtime input can be requested as an output."""
    builder = driver.Builder()
    builder = builder.with_modules(tests.resources.overrides)
    builder = builder.with_adapter(base.DefaultAdapter())
    builder = builder.enable_dynamic_execution(allow_experimental_mode=True)
    dynamic_driver = builder.build()

    result = dynamic_driver.execute(["d", "e"], inputs={"a": 1, "e": 10})

    # "d" is requested alongside "e", which is supplied only as an input.
    assert result["d"] == 4
    assert result["e"] == 10


def test_driver_fails_on_outputs_not_in_input():
    """Checks that requesting an output found in neither the graph nor the inputs raises."""
    builder = driver.Builder()
    builder = builder.with_modules(tests.resources.overrides)
    builder = builder.with_adapter(base.DefaultAdapter())
    default_driver = builder.build()

    # "f" is the missing variable: it is not passed in the inputs.
    with pytest.raises(ValueError):
        default_driver.execute(["d", "f"], inputs={"a": 1})


def test_driver_v2_fails_on_outputs_not_in_input():
    """Checks, with dynamic execution enabled, that an output in neither graph nor inputs raises."""
    builder = driver.Builder()
    builder = builder.with_modules(tests.resources.overrides)
    builder = builder.with_adapter(base.DefaultAdapter())
    builder = builder.enable_dynamic_execution(allow_experimental_mode=True)
    dynamic_driver = builder.build()

    # "f" is the missing variable: it is not passed in the inputs.
    with pytest.raises(ValueError):
        dynamic_driver.execute(["d", "f"], inputs={"a": 1})


def test_driver_v2_inputs_can_be_none():
    """Checks that passing ``inputs=None`` neither breaks execution nor the missing-output check.

    NOTE(review): despite the ``v2`` in the name, the builder here does not call
    ``enable_dynamic_execution`` -- confirm whether that is intended.
    """
    builder = driver.Builder()
    builder = builder.with_modules(tests.resources.overrides)
    builder = builder.with_adapter(base.DefaultAdapter())
    default_driver = builder.build()

    result = default_driver.execute(["d"], inputs=None, overrides={"b": 1})
    assert result["d"] == 3

    # With no inputs at all, requesting "e" must still fail cleanly with ValueError.
    with pytest.raises(ValueError):
        default_driver.execute(["e"], inputs=None)

0 comments on commit 4771dd0

Please sign in to comment.