From 28c9d1473a3b2d4a2bafe8aea2f834a9e064d8d1 Mon Sep 17 00:00:00 2001 From: jernejfrank Date: Sat, 16 Nov 2024 07:48:35 +0800 Subject: [PATCH] Add e2e tests for async pipe_family Additional integration tests for async pipe_input, pipe_output and mutate. --- tests/function_modifiers/test_macros.py | 62 ++++++++++++++++++++++++- tests/resources/mutate_async.py | 21 +++++++++ tests/resources/pipe_async.py | 35 ++++++++++++++ 3 files changed, 117 insertions(+), 1 deletion(-) create mode 100644 tests/resources/mutate_async.py create mode 100644 tests/resources/pipe_async.py diff --git a/tests/function_modifiers/test_macros.py b/tests/function_modifiers/test_macros.py index 64b53bf8e..dc0963203 100644 --- a/tests/function_modifiers/test_macros.py +++ b/tests/function_modifiers/test_macros.py @@ -5,7 +5,7 @@ import pytest import hamilton.function_modifiers -from hamilton import base, driver, function_modifiers, models, node +from hamilton import async_driver, base, driver, function_modifiers, models, node from hamilton.function_modifiers import does from hamilton.function_modifiers.dependencies import source, value from hamilton.function_modifiers.macros import ( @@ -20,6 +20,8 @@ from hamilton.node import DependencyType import tests.resources.mutate +import tests.resources.mutate_async +import tests.resources.pipe_async import tests.resources.pipe_input import tests.resources.pipe_output @@ -1150,3 +1152,61 @@ def test_mutate_end_to_end_1(import_mutate_module): ) assert result["chain_1_using_mutate"] == result["chain_1_not_using_mutate"] assert result["chain_2_using_mutate"] == result["chain_2_not_using_mutate"] + + +@pytest.mark.asyncio +async def test_async_pipe_input_and_output_end_to_end(): + inputs = {"data_input": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})} + + group_by_a = inputs["data_input"].groupby("a").sum().reset_index() + group_by_b = inputs["data_input"].groupby("b").sum().reset_index() + + dr = ( + await async_driver.Builder() + .with_modules(tests.resources.pipe_async) + .with_config(dict(groupby="a")) + .build() + ) + results = await dr.execute(final_vars=["data_pipe_input", "data_pipe_output"], inputs=inputs) + + pd.testing.assert_frame_equal(group_by_a, results["data_pipe_output"]) + pd.testing.assert_frame_equal(group_by_a, results["data_pipe_input"]) + + dr = ( + await async_driver.Builder() + .with_modules(tests.resources.pipe_async) + .with_config(dict(groupby="b")) + .build() + ) + results = await dr.execute(final_vars=["data_pipe_input", "data_pipe_output"], inputs=inputs) + + pd.testing.assert_frame_equal(group_by_b, results["data_pipe_output"]) + pd.testing.assert_frame_equal(group_by_b, results["data_pipe_input"]) + + +@pytest.mark.asyncio +async def test_async_mutate_end_to_end(): + inputs = {"data_input": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})} + + group_by_a = inputs["data_input"].groupby("a").sum().reset_index() + group_by_b = inputs["data_input"].groupby("b").sum().reset_index() + + dr = ( + await async_driver.Builder() + .with_modules(tests.resources.mutate_async) + .with_config(dict(groupby="a")) + .build() + ) + results = await dr.execute(final_vars=["data_mutate"], inputs=inputs) + + pd.testing.assert_frame_equal(group_by_a, results["data_mutate"]) + + dr = ( + await async_driver.Builder() + .with_modules(tests.resources.mutate_async) + .with_config(dict(groupby="b")) + .build() + ) + results = await dr.execute(final_vars=["data_mutate"], inputs=inputs) + + pd.testing.assert_frame_equal(group_by_b, results["data_mutate"]) diff --git a/tests/resources/mutate_async.py b/tests/resources/mutate_async.py new file mode 100644 index 000000000..0bfc2b38e --- /dev/null +++ b/tests/resources/mutate_async.py @@ -0,0 +1,21 @@ +import asyncio + +import pandas as pd + +from hamilton.function_modifiers import apply_to, mutate + + +def data_mutate(data_input: pd.DataFrame) -> pd.DataFrame: + return data_input + + +@mutate(apply_to(data_mutate).when(groupby="a")) +async def _groupby_a_mutate(d: pd.DataFrame) -> pd.DataFrame: + await asyncio.sleep(0.0001) + return d.groupby("a").sum().reset_index() + + +@mutate(apply_to(data_mutate).when_not(groupby="a")) +async def _groupby_b_mutate(d: pd.DataFrame) -> pd.DataFrame: + await asyncio.sleep(0.0001) + return d.groupby("b").sum().reset_index() diff --git a/tests/resources/pipe_async.py b/tests/resources/pipe_async.py new file mode 100644 index 000000000..826f73107 --- /dev/null +++ b/tests/resources/pipe_async.py @@ -0,0 +1,35 @@ +import asyncio + +import pandas as pd + +from hamilton.function_modifiers import pipe_input, pipe_output, step + +# async def data_input() -> pd.DataFrame: +# await asyncio.sleep(0.0001) +# return + + +async def _groupby_a(d: pd.DataFrame) -> pd.DataFrame: + await asyncio.sleep(0.0001) + return d.groupby("a").sum().reset_index() + + +async def _groupby_b(d: pd.DataFrame) -> pd.DataFrame: + await asyncio.sleep(0.0001) + return d.groupby("b").sum().reset_index() + + +@pipe_input( + step(_groupby_a).when(groupby="a"), + step(_groupby_b).when_not(groupby="a"), +) +def data_pipe_input(data_input: pd.DataFrame) -> pd.DataFrame: + return data_input + + +@pipe_output( + step(_groupby_a).when(groupby="a"), + step(_groupby_b).when_not(groupby="a"), +) +def data_pipe_output(data_input: pd.DataFrame) -> pd.DataFrame: + return data_input