Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enables inputs to be also outputs; Fixes #936 #947

Merged
merged 3 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,13 @@ def execute(
"""Basic executor for a function graph. Does no task-based execution, just does a DFS
and executes the graph in order, in memory."""
memoized_computation = dict() # memoized storage
nodes = [fg.nodes[node_name] for node_name in final_vars]
nodes = [fg.nodes[node_name] for node_name in final_vars if node_name in fg.nodes]
fg.execute(nodes, memoized_computation, overrides, inputs, run_id=run_id)
outputs = {
final_var: memoized_computation[final_var] for final_var in final_vars
# we do this here to enable inputs to also be used as outputs
# putting inputs into memoized before execution doesn't work due to some graphadapter assumptions.
final_var: memoized_computation.get(final_var, inputs.get(final_var))
for final_var in final_vars
skrawcz marked this conversation as resolved.
Show resolved Hide resolved
} # only want request variables in df.
del memoized_computation # trying to cleanup some memory
return outputs
Expand Down
14 changes: 12 additions & 2 deletions hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,9 @@ def next_nodes_function(n: node.Node) -> List[node.Node]:
deps.append(dep)
return deps

return self.directional_dfs_traverse(next_nodes_function, starting_nodes=final_vars)
return self.directional_dfs_traverse(
skrawcz marked this conversation as resolved.
Show resolved Hide resolved
next_nodes_function, starting_nodes=final_vars, runtime_inputs=runtime_inputs
)

def nodes_between(self, start: str, end: str) -> Set[node.Node]:
"""Given our function graph, and a list of desired output variables, returns the subgraph
Expand All @@ -1006,15 +1008,19 @@ def directional_dfs_traverse(
self,
next_nodes_fn: Callable[[node.Node], Collection[node.Node]],
starting_nodes: List[str],
runtime_inputs: Dict[str, Any] = None,
):
"""Traverses the DAG directionally using a DFS.

:param next_nodes_fn: Function to give the next set of nodes
:param starting_nodes: Which nodes to start at.
:param runtime_inputs: runtime inputs to the DAG. This is here to allow for inputs to be also outputs.
:return: a tuple of sets:
- set of all nodes.
- subset of nodes that human input is required for.
"""
if runtime_inputs is None:
runtime_inputs = {}
nodes = set()
user_nodes = set()

Expand All @@ -1029,7 +1035,11 @@ def dfs_traverse(node: node.Node):
missing_vars = []
for var in starting_nodes:
if var not in self.nodes and var not in self.config:
missing_vars.append(var)
# checking for runtime_inputs because it's not in the graph isn't really a graph concern. So perhaps we
# should move this outside of the graph in the future. This will do fine for now.
if var not in runtime_inputs:
# if it's not in the runtime inputs, it's a properly missing variable
missing_vars.append(var)
continue # collect all missing final variables
dfs_traverse(self.nodes[var])
if missing_vars:
Expand Down
96 changes: 96 additions & 0 deletions tests/test_end_to_end.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import importlib
import json
import math
import sys
from typing import Any, Callable, Dict, List, Type

import pandas as pd
import pytest

from hamilton import ad_hoc_utils, base, driver, settings
Expand All @@ -14,6 +16,7 @@

import tests.resources.data_quality
import tests.resources.dynamic_config
import tests.resources.example_module
import tests.resources.overrides
import tests.resources.test_for_materialization

Expand Down Expand Up @@ -448,3 +451,96 @@ def test_driver_validate_with_overrides_2():
.build()
)
assert dr.execute(["d"], overrides={"b": 1})["d"] == 3


def test_driver_extra_inputs_can_be_outputs():
"""Tests that we can request outputs that not in the graph, but are in the inputs."""
dr = (
driver.Builder()
.with_modules(tests.resources.overrides)
.with_adapter(base.DefaultAdapter())
.build()
)
actual = dr.execute(["d", "e"], inputs={"a": 1, "e": 10})
assert actual["d"] == 4
assert actual["e"] == 10
# validates validate functions
dr.validate_execution(["d", "e"], inputs={"a": 1, "e": 10})
dr.validate_materialization(additional_vars=["d", "e"], inputs={"a": 1, "e": 10})
# Checks dataframe use case
dr = (
driver.Builder()
.with_modules(tests.resources.example_module)
.with_adapter(base.PandasDataFrameResult())
.build()
)
actual = dr.execute(
["avg_3wk_spend", "e"],
inputs={"spend": pd.Series([1, 1, 1, 1, 1]), "e": pd.Series([10, 10, 10, 10, 10])},
)
pd.testing.assert_frame_equal(
actual,
pd.DataFrame(
{
"avg_3wk_spend": pd.Series([math.nan, math.nan, 1.0, 1.0, 1.0], dtype=float),
skrawcz marked this conversation as resolved.
Show resolved Hide resolved
"e": [10, 10, 10, 10, 10],
}
),
)


def test_driver_v2_extra_inputs_can_be_outputs():
skrawcz marked this conversation as resolved.
Show resolved Hide resolved
"""Tests that we can request outputs that not in the graph, but are in the inputs."""
dr = (
driver.Builder()
.with_modules(tests.resources.overrides)
.with_adapter(base.DefaultAdapter())
.enable_dynamic_execution(allow_experimental_mode=True)
.build()
)
actual = dr.execute(["d", "e"], inputs={"a": 1, "e": 10})
assert actual["d"] == 4
assert actual["e"] == 10


def test_driver_fails_on_outputs_not_in_input():
"""Tests that we fail correctly on outputs that are not in the inputs or the graph."""
dr = (
driver.Builder()
.with_modules(tests.resources.overrides)
.with_adapter(base.DefaultAdapter())
.build()
)
with pytest.raises(ValueError):
# f missing
dr.execute(["d", "f"], inputs={"a": 1})


def test_driver_v2_fails_on_outputs_not_in_input():
"""Tests that we fail correctly on outputs that are not in the inputs or the graph."""
dr = (
driver.Builder()
.with_modules(tests.resources.overrides)
.with_adapter(base.DefaultAdapter())
.enable_dynamic_execution(allow_experimental_mode=True)
.build()
)
with pytest.raises(ValueError):
# f missing
dr.execute(["d", "f"], inputs={"a": 1})


def test_driver_v2_inputs_can_be_none():
"""Tests that input can be None and checks will still work."""
dr = (
driver.Builder()
.with_modules(tests.resources.overrides)
.with_adapter(base.DefaultAdapter())
.build()
)
actual = dr.execute(["d"], inputs=None, overrides={"b": 1})
assert actual["d"] == 3

with pytest.raises(ValueError):
# validate that None doesn't cause issues
dr.execute(["e"], inputs=None)