Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jernejfrank committed Oct 30, 2024
1 parent a28706e commit 0c41502
Show file tree
Hide file tree
Showing 2 changed files with 303 additions and 3 deletions.
267 changes: 264 additions & 3 deletions tests/function_modifiers/test_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,25 @@
import random
from typing import Tuple

import pandas as pd
import pytest

from hamilton import ad_hoc_utils, graph
from hamilton import ad_hoc_utils, driver, graph, node
from hamilton.function_modifiers import (
InvalidDecoratorException,
config,
parameterized_subdag,
recursive,
subdag,
value,
with_columns,
)
from hamilton.function_modifiers.base import NodeTransformer
from hamilton.function_modifiers.base import NodeInjector, NodeTransformer
from hamilton.function_modifiers.dependencies import source
from hamilton.function_modifiers.recursive import _validate_config_inputs

import tests.resources.reuse_subdag
import tests.resources.with_columns


def test_collect_function_fns():
Expand Down Expand Up @@ -181,7 +184,7 @@ def test_reuse_subdag_end_to_end():
fg = graph.FunctionGraph.from_modules(tests.resources.reuse_subdag, config={"op": "subtract"})
prefixless_nodes = []
prefixed_nodes = collections.defaultdict(list)
for name, node in fg.nodes.items():
for name, node in fg.nodes.items(): # noqa:F402
name_split = name.split(".")
if len(name_split) == 1:
prefixless_nodes.append(node)
Expand Down Expand Up @@ -539,3 +542,261 @@ def test_recursive_validate_config_inputs_happy(config, inputs):
def test_recursive_validate_config_inputs_sad(config, inputs):
with pytest.raises(InvalidDecoratorException):
_validate_config_inputs(config, inputs)


def dummy_fn_with_columns(col_1: pd.Series) -> pd.Series:
    """Return ``col_1`` with 100 added to each element."""
    shifted = col_1 + 100
    return shifted


def test_detect_duplicate_nodes():
    """Check ``with_columns._check_for_duplicates`` on node lists with and without name clashes.

    Three nodes are built from the same function; two of them share the name
    ``"a"``. The helper is expected to return a truthy value when both are
    present and a falsy value when the names are unique.
    """
    node_a = node.Node.from_fn(dummy_fn_with_columns, name="a")
    node_b = node.Node.from_fn(dummy_fn_with_columns, name="a")
    node_c = node.Node.from_fn(dummy_fn_with_columns, name="c")

    # node_a and node_b collide on the name "a".
    assert with_columns._check_for_duplicates([node_a, node_b, node_c])

    # Unique names only: nothing should be reported.
    assert not with_columns._check_for_duplicates([node_a, node_c])


def test_select_not_empty():
    """Calling ``with_columns`` without ``select`` must raise ``ValueError`` with this message."""
    expected = "Please specify at least one column to append or update."

    with pytest.raises(ValueError) as exc_info:
        with_columns(dummy_fn_with_columns)

    # Checked after the context manager exits so the assertion actually runs.
    assert str(exc_info.value) == expected


def test_columns_to_pass_and_pass_dataframe_as_raises_error():
    """Passing both ``columns_to_pass`` and ``pass_dataframe_as`` must raise ``ValueError``."""
    # NOTE(review): the missing spaces between some fragments and the "sbudag"
    # spelling presumably mirror the library's message verbatim -- the test
    # compares for exact equality, so do not "fix" them here alone.
    expected = (
        "You must specify only one of columns_to_pass and "
        "pass_dataframe_as. "
        "This is because specifying pass_dataframe_as injects into "
        "the set of columns, allowing you to perform your own extraction"
        "from the dataframe. We then execute all columns in the sbudag"
        "in order, passing in that initial dataframe. If you want"
        "to reference columns in your code, you'll have to specify "
        "the set of initial columns, and allow the subdag decorator "
        "to inject the dataframe through. The initial columns tell "
        "us which parameters to take from that dataframe, so we can"
        "feed the right data into the right columns."
    )

    with pytest.raises(ValueError) as exc_info:
        with_columns(
            dummy_fn_with_columns,
            columns_to_pass=["a"],
            pass_dataframe_as="a",
            select="a",
        )

    assert str(exc_info.value) == expected


def test_create_column_nodes_pass_dataframe():
    """With ``pass_dataframe_as`` set, ``_get_inital_nodes`` yields no column nodes."""

    def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame:
        return upstream_df

    target_node = node.Node.from_fn(target_fn)
    decorator = with_columns(
        dummy_fn_with_columns,
        pass_dataframe_as="upstream_df",
        select="dummy_fn_with_columns",
    )
    params = NodeInjector.find_injectable_params([target_node])

    inject_parameter, initial_nodes = decorator._get_inital_nodes(fn=target_fn, params=params)

    # The dataframe parameter is the injection point; nothing is extracted.
    assert inject_parameter == "upstream_df"
    assert len(initial_nodes) == 0


def test_create_column_nodes_extract_single_columns():
    """A single string in ``columns_to_pass`` produces one node that extracts that column."""

    def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame:
        return upstream_df

    target_node = node.Node.from_fn(target_fn)
    decorator = with_columns(
        dummy_fn_with_columns,
        columns_to_pass="col_1",
        select="dummy_fn_with_columns",
    )
    params = NodeInjector.find_injectable_params([target_node])

    inject_parameter, initial_nodes = decorator._get_inital_nodes(fn=target_fn, params=params)

    assert inject_parameter == "upstream_df"
    assert len(initial_nodes) == 1

    column_node = initial_nodes[0]
    assert column_node.name == "col_1"
    assert column_node.type == pd.Series

    # Run the generated node against a concrete frame and compare values only.
    frame = pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]})
    extracted = column_node.callable(upstream_df=frame)
    pd.testing.assert_series_equal(extracted, pd.Series([1, 2, 3, 4]), check_names=False)


def test_create_column_nodes_extract_multiple_columns():
    """A list in ``columns_to_pass`` produces one extraction node per column, in order."""

    def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame:
        return upstream_df

    source_columns = {"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]}

    target_node = node.Node.from_fn(target_fn)
    decorator = with_columns(
        dummy_fn_with_columns,
        columns_to_pass=["col_1", "col_2"],
        select="dummy_fn_with_columns",
    )
    params = NodeInjector.find_injectable_params([target_node])

    inject_parameter, initial_nodes = decorator._get_inital_nodes(fn=target_fn, params=params)

    assert inject_parameter == "upstream_df"
    assert len(initial_nodes) == 2

    # Names first, then types, then extracted values -- same order of checks
    # as a hand-written sequence of asserts.
    for position, column_name in enumerate(source_columns):
        assert initial_nodes[position].name == column_name
    for position in range(2):
        assert initial_nodes[position].type == pd.Series
    for position, values in enumerate(source_columns.values()):
        pd.testing.assert_series_equal(
            initial_nodes[position].callable(upstream_df=pd.DataFrame(source_columns)),
            pd.Series(values),
            check_names=False,
        )


def test_no_matching_select_column_error():
    """Selecting a column no subdag function produces must raise ``ValueError``."""

    def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame:
        return upstream_df

    select = "wrong_column"
    target_node = node.Node.from_fn(target_fn)

    decorator = with_columns(
        dummy_fn_with_columns,
        columns_to_pass=["col_1", "col_2"],
        select=select,
    )
    params = NodeInjector.find_injectable_params([target_node])

    expected = (
        f"No nodes found upstream from select columns: {[select]} for function: "
        f"{target_fn.__qualname__}"
    )

    with pytest.raises(ValueError) as exc_info:
        decorator.inject_nodes(params, {}, fn=target_fn)

    assert str(exc_info.value) == expected


def test_append_into_original_df():
    """The merge node appends the selected column while leaving the original columns intact.

    Builds the merge node directly via ``_create_merge_node`` and calls it with
    a two-column frame plus the output of ``dummy_fn_with_columns``; the result
    must contain both original columns and the new one.
    """
    upstream_df = pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]})

    decorator = with_columns(
        dummy_fn_with_columns,
        columns_to_pass=["col_1", "col_2"],
        select="dummy_fn_with_columns",
    )
    merge_node = decorator._create_merge_node(upstream_node="upstream_df", node_name="merge_node")

    output_df = merge_node.callable(
        upstream_df=upstream_df,
        dummy_fn_with_columns=dummy_fn_with_columns(col_1=pd.Series([1, 2, 3, 4])),
    )

    assert merge_node.name == "merge_node"
    assert merge_node.type == pd.DataFrame

    expected_columns = {
        "col_1": [1, 2, 3, 4],
        "col_2": [11, 12, 13, 14],
        "dummy_fn_with_columns": [101, 102, 103, 104],
    }
    for column_name, values in expected_columns.items():
        # Series names are ignored; only index and values are compared.
        pd.testing.assert_series_equal(
            output_df[column_name], pd.Series(values), check_names=False
        )


def test_override_original_column_in_df():
    """The merge node overwrites an existing column when the selected node shares its name.

    ``col_1`` here is a subdag function whose name matches a column already in
    the upstream frame; after the merge the frame's ``col_1`` must hold the new
    values while ``col_2`` stays untouched.
    """

    def col_1() -> pd.Series:
        return pd.Series([0, 3, 5, 7])

    upstream_df = pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]})

    decorator = with_columns(col_1, pass_dataframe_as="upstream_df", select="col_1")
    merge_node = decorator._create_merge_node(upstream_node="upstream_df", node_name="merge_node")

    output_df = merge_node.callable(upstream_df=upstream_df, col_1=col_1())

    assert merge_node.name == "merge_node"
    assert merge_node.type == pd.DataFrame

    # col_1 is replaced, col_2 is carried over unchanged.
    pd.testing.assert_series_equal(output_df["col_1"], pd.Series([0, 3, 5, 7]), check_names=False)
    pd.testing.assert_series_equal(
        output_df["col_2"], pd.Series([11, 12, 13, 14]), check_names=False
    )


def test_assign_custom_namespace_with_columns():
    """Nodes generated by ``with_columns`` are prefixed with the custom ``namespace``."""

    def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame:
        return upstream_df

    target_node = node.Node.from_fn(target_fn)
    decorator = with_columns(
        dummy_fn_with_columns,
        columns_to_pass=["col_1", "col_2"],
        select="dummy_fn_with_columns",
        namespace="dummy_namespace",
    )
    nodes_ = decorator.transform_dag([target_node], {}, target_fn)

    # The target keeps its own name; every generated node carries the prefix.
    expected_names = (
        "target_fn",
        "dummy_namespace.col_1",
        "dummy_namespace.col_2",
        "dummy_namespace.dummy_fn_with_columns",
        "dummy_namespace.__append",
    )
    for position, expected_name in enumerate(expected_names):
        assert nodes_[position].name == expected_name


def test_end_to_end_with_columns_automatic_extract():
    """Execute ``final_df`` through a driver and compare against the expected frame."""
    dr = driver.Builder().with_modules(tests.resources.with_columns).build()
    outputs = dr.execute(final_vars=["final_df"])
    result = outputs["final_df"]

    # Original three columns followed by the two selected, derived columns.
    expected_columns = {
        "col_1": [1, 2, 3, 4],
        "col_2": [11, 12, 13, 14],
        "col_3": [1, 1, 1, 1],
        "substract_1_from_2": [10, 10, 10, 10],
        "multiply_3_by_5": [5, 5, 5, 5],
    }
    expected_df = pd.DataFrame(expected_columns)

    pd.testing.assert_frame_equal(result, expected_df)


def test_end_to_end_with_columns_pass_dataframe():
    """Execute ``final_df_2`` through a driver and compare against the expected frame."""
    # NOTE(review): the resources module visible alongside this test does not
    # appear to read the "case" config key -- confirm whether it is needed.
    builder = driver.Builder()
    builder = builder.with_modules(tests.resources.with_columns)
    builder = builder.with_config({"case": "override_columns"})
    dr = builder.build()

    outputs = dr.execute(final_vars=["final_df_2"])
    result = outputs["final_df_2"]

    # col_3 is overridden and multiply_3_by_5 is appended.
    expected_columns = {
        "col_1": [1, 2, 3, 4],
        "col_2": [11, 12, 13, 14],
        "col_3": [0, 2, 4, 6],
        "multiply_3_by_5": [0, 10, 20, 30],
    }
    expected_df = pd.DataFrame(expected_columns)

    pd.testing.assert_frame_equal(result, expected_df)
39 changes: 39 additions & 0 deletions tests/resources/with_columns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pandas as pd

from hamilton.function_modifiers import with_columns


def initial_df() -> pd.DataFrame:
    """Build the fixed three-column, four-row frame used as the upstream input."""
    columns = {
        "col_1": [1, 2, 3, 4],
        "col_2": [11, 12, 13, 14],
        "col_3": [1, 1, 1, 1],
    }
    return pd.DataFrame(columns)


def substract_1_from_2(col_1: pd.Series, col_2: pd.Series) -> pd.Series:
    """Return the element-wise difference ``col_2 - col_1``."""
    difference = col_2 - col_1
    return difference


def multiply_3_by_5(col_3: pd.Series) -> pd.Series:
    """Return ``col_3`` with every element multiplied by 5."""
    scaled = col_3 * 5
    return scaled


@with_columns(
    substract_1_from_2,
    multiply_3_by_5,
    columns_to_pass=["col_1", "col_2", "col_3"],
    select=["substract_1_from_2", "multiply_3_by_5"],
    namespace="some_subdag",
)
def final_df(initial_df: pd.DataFrame) -> pd.DataFrame:
    """Return the incoming frame as-is.

    The body is a pass-through; the ``with_columns`` decorator above is
    configured to pass ``col_1``..``col_3`` to the two column functions and to
    select both of their outputs, under the ``some_subdag`` namespace.
    """
    frame = initial_df
    return frame


def col_3(initial_df: pd.DataFrame) -> pd.Series:
    """Return a fixed replacement series ``[0, 2, 4, 6]``; ``initial_df`` is not read."""
    replacement_values = [0, 2, 4, 6]
    return pd.Series(replacement_values)


@with_columns(
    col_3,
    # multiply_3_by_5 is listed in ``select`` below, so it is passed to the
    # decorator explicitly, the same way ``final_df`` lists every function it
    # selects, rather than relying on it being resolved from elsewhere.
    multiply_3_by_5,
    pass_dataframe_as="initial_df",
    select=["col_3", "multiply_3_by_5"],
)
def final_df_2(initial_df: pd.DataFrame) -> pd.DataFrame:
    """Return the incoming frame as-is.

    The body is a pass-through; the ``with_columns`` decorator above hands the
    whole frame to the column functions as ``initial_df`` and selects
    ``col_3`` (which shares its name with an existing column) and
    ``multiply_3_by_5``.
    """
    return initial_df

0 comments on commit 0c41502

Please sign in to comment.