Adds FunctionInputOutputTypeChecker #757

Merged Mar 12, 2024 · 5 commits
Changes from 3 commits
Binary file modified examples/pandas/split-apply-combine/my_full_dag.png
10 changes: 2 additions & 8 deletions examples/pandas/split-apply-combine/my_functions.py
@@ -55,24 +55,18 @@ def _tax_credit(df: DataFrame, tax_credits: Dict[str, float]) -> DataFrame:

 @extract_fields({"under_100k": DataFrame, "over_100k": DataFrame})
 # Step 1: DataFrame is split in 2 DataFrames
-def split_dataframe(
-    input: DataFrame, tax_rates: Dict[str, float], tax_credits: Dict[str, float]
-) -> Dict[str, DataFrame]:
+def split_dataframe(input: DataFrame) -> Dict[str, DataFrame]:
     """
     This function takes the input DataFrame and splits it into 2 DataFrames:
         - under_100k: rows where 'Income' is under 100k
        - over_100k: rows where 'Income' is 100k or over

     :param input: the DataFrame to process
-    :param tax_rates: The Tax Rates rules
-    :param tax_credits: The Tax Credits rules
     :return: a Dict with the DataFrames and the Tax Rates & Credit rules
     """
     return {
         "under_100k": input.query("Income < 100000"),
         "over_100k": input.query("Income >= 100000"),
-        "tax_rates": tax_rates,
-        "tax_credits": tax_credits,
     }
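
For context, @extract_fields registers each key of the returned dict as its own node, so downstream functions can depend on "under_100k" or "over_100k" by parameter name. A minimal sketch with a hypothetical consumer (not part of this example):

# Hypothetical downstream node: Hamilton wires it to the "under_100k"
# field extracted from split_dataframe's returned dict.
def under_100k_count(under_100k: DataFrame) -> int:
    return int(len(under_100k))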


@@ -84,7 +78,7 @@ def split_dataframe(
 def under_100k_tax(under_100k: DataFrame) -> DataFrame:
     """
     Tax calculation pipeline for 'Income' under 100k.
-    :param df: The DataFrame where 'Income' is under 100k
+    :param under_100k: The DataFrame where 'Income' is under 100k
     :return: the DataFrame with the 'Tax' Series
     """
     return under_100k
2 changes: 1 addition & 1 deletion examples/pandas/split-apply-combine/my_script.py
@@ -48,7 +48,7 @@ def read_table(table: str, delimiter="|") -> DataFrame:
 }

 tax_credits = {
-    "Children == 0": 0,  # 0 children: Tax credit 0 %
+    "Children == 0": 0.0,  # 0 children: Tax credit 0 %
     "Children == 1": 0.02,  # 1 child: Tax credit 2 %
     "Children == 2": 0.04,  # 2 children: Tax credit 4 %
     "Children == 3": 0.06,  # 3 children: Tax credit 6 %
10 changes: 8 additions & 2 deletions examples/pandas/split-apply-combine/my_wrapper.py
@@ -3,9 +3,15 @@
 import my_functions
 from pandas import DataFrame

-from hamilton import driver
+from hamilton import base, driver, lifecycle

-driver = driver.Driver({}, my_functions)
+driver = (
+    driver.Builder()
+    .with_config({})
+    .with_modules(my_functions)
+    .with_adapters(lifecycle.FunctionInputOutputTypeChecker(), base.PandasDataFrameResult())
+    .build()
+)


 class TaxCalculator:
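
With the adapters wired in above, a type mismatch now fails fast at execution time. A minimal sketch of what that looks like (the requested outputs and the input DataFrame here are hypothetical, not taken from the example):

import pandas as pd

df = pd.DataFrame({"Income": [50_000, 150_000], "Children": [1, 2]})

# Each node's inputs and output are checked as it runs; passing, say, a dict
# where split_dataframe declares a DataFrame raises a TypeError immediately.
result = driver.execute(["under_100k_tax", "over_100k_tax"], inputs={"input": df})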
100 changes: 60 additions & 40 deletions hamilton/execution/graph_functions.py
@@ -119,6 +119,32 @@ def create_input_string(kwargs: dict) -> str:
     return input_string


+def create_error_message(kwargs: dict, node_: node.Node, step: str) -> str:
+    """Creates an error message for a node that errored."""
+    # This code is coupled to how @config resolution works. Ideally it shouldn't be,
+    # so when @config resolvers are changed to return Nodes, then fn.__name__ should
+    # just work.
+    original_func_name = "unknown"
+    if node_.originating_functions:
+        if hasattr(node_.originating_functions[0], "__original_name__"):
+            original_func_name = node_.originating_functions[0].__original_name__
+        else:
+            original_func_name = node_.originating_functions[0].__name__
+    module = (
+        node_.originating_functions[0].__module__
+        if node_.originating_functions and hasattr(node_.originating_functions[0], "__module__")
+        else "unknown_module"
+    )
+    message = f">{step} {node_.name} [{module}.{original_func_name}()] encountered an error"
+    padding = " " * (80 - min(len(message), 79) - 1)
+    message += padding + "<"
+    input_string = create_input_string(kwargs)
+    message += "\n> Node inputs:\n" + input_string
+    border = "*" * 80
+    message = "\n" + border + "\n" + message + "\n" + border
+    return message
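
For illustration, the banner this helper produces looks roughly like the following (node name, module, and inputs are hypothetical; the inputs section comes from create_input_string):

********************************************************************************
>[pre-node-execute] my_node [my_module.my_node()] encountered an error         <
> Node inputs:
{'some_input': ...}
********************************************************************************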


 def execute_subdag(
     nodes: Collection[node.Node],
     inputs: Dict[str, Any],
@@ -178,15 +204,21 @@ def dfs_traverse(
         error = None
         result = None
         success = True
+        pre_node_execute_errored = False
         try:
             if adapter.does_hook("pre_node_execute", is_async=False):
-                adapter.call_all_lifecycle_hooks_sync(
-                    "pre_node_execute",
-                    run_id=run_id,
-                    node_=node_,
-                    kwargs=kwargs,
-                    task_id=task_id,
-                )
+                try:
+                    adapter.call_all_lifecycle_hooks_sync(
+                        "pre_node_execute",
+                        run_id=run_id,
+                        node_=node_,
+                        kwargs=kwargs,
+                        task_id=task_id,
+                    )
+                except Exception as e:
+                    pre_node_execute_errored = True
+                    raise e

             if adapter.does_method("do_node_execute", is_async=False):
                 result = adapter.call_lifecycle_method_sync(
                     "do_node_execute",
@@ -200,41 +232,29 @@ def dfs_traverse(
         except Exception as e:
             success = False
             error = e
-            # This code is coupled to how @config resolution works. Ideally it shouldn't be,
-            # so when @config resolvers are changed to return Nodes, then fn.__name__ should
-            # just work.
-            original_func_name = "unknown"
-            if node_.originating_functions:
-                if hasattr(node_.originating_functions[0], "__original_name__"):
-                    original_func_name = node_.originating_functions[0].__original_name__
-                else:
-                    original_func_name = node_.originating_functions[0].__name__
-            module = (
-                node_.originating_functions[0].__module__
-                if node_.originating_functions
-                and hasattr(node_.originating_functions[0], "__module__")
-                else "unknown_module"
-            )
-            message = f"> {node_.name} [{module}.{original_func_name}()] encountered an error"
-            padding = " " * (80 - len(message) - 1)
-            message += padding + "<"
-            input_string = create_input_string(kwargs)
-            message += "\n> Node inputs:\n" + input_string
-            border = "*" * 80
-            logger.exception("\n" + border + "\n" + message + "\n" + border)
+            step = "[pre-node-execute]" if pre_node_execute_errored else ""
+            message = create_error_message(kwargs, node_, step)
+            logger.exception(message)
             raise
         finally:
-            if adapter.does_hook("post_node_execute", is_async=False):
-                adapter.call_all_lifecycle_hooks_sync(
-                    "post_node_execute",
-                    run_id=run_id,
-                    node_=node_,
-                    kwargs=kwargs,
-                    success=success,
-                    error=error,
-                    result=result,
-                    task_id=task_id,
-                )
+            if not pre_node_execute_errored and adapter.does_hook(
+                "post_node_execute", is_async=False
+            ):
+                try:
+                    adapter.call_all_lifecycle_hooks_sync(
+                        "post_node_execute",
+                        run_id=run_id,
+                        node_=node_,
+                        kwargs=kwargs,
+                        success=success,
+                        error=error,
+                        result=result,
+                        task_id=task_id,
+                    )
+                except Exception:
+                    message = create_error_message(kwargs, node_, "[post-node-execute]")
+                    logger.exception(message)
+                    raise

         computed[node_.name] = result
         # > pruning the graph

Collaborator review comment (on the `if not pre_node_execute_errored` line): Hmm we should really have a context manager that does this... But this works for now.
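
The context-manager refactor the reviewer alludes to might look roughly like this (a hypothetical sketch, not part of this PR; it reuses the adapter API from the code above, and raising before the try mirrors how a pre-hook failure skips the post hook):

from contextlib import contextmanager

@contextmanager
def node_execution_hooks(adapter, *, run_id, node_, kwargs, task_id):
    """Hypothetical helper wrapping a node run in pre/post lifecycle hooks."""
    # Raising here, before the try, means the post hook never fires --
    # the same behavior the pre_node_execute_errored flag encodes above.
    if adapter.does_hook("pre_node_execute", is_async=False):
        adapter.call_all_lifecycle_hooks_sync(
            "pre_node_execute", run_id=run_id, node_=node_, kwargs=kwargs, task_id=task_id
        )
    state = {"success": True, "error": None, "result": None}
    try:
        yield state
    except Exception as e:
        state["success"], state["error"] = False, e
        raise
    finally:
        if adapter.does_hook("post_node_execute", is_async=False):
            adapter.call_all_lifecycle_hooks_sync(
                "post_node_execute",
                run_id=run_id,
                node_=node_,
                kwargs=kwargs,
                success=state["success"],
                error=state["error"],
                result=state["result"],
                task_id=task_id,
            )

# Hypothetical usage inside dfs_traverse:
#     with node_execution_hooks(adapter, run_id=run_id, node_=node_,
#                               kwargs=kwargs, task_id=task_id) as state:
#         state["result"] = node_(**kwargs)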
8 changes: 7 additions & 1 deletion hamilton/lifecycle/__init__.py
@@ -11,7 +11,12 @@
     TaskExecutionHook,
 )
 from .base import LifecycleAdapter  # noqa: F401
-from .default import PDBDebugger, PrintLn, SlowDownYouMoveTooFast  # noqa: F401
+from .default import (  # noqa: F401
+    FunctionInputOutputTypeChecker,
+    PDBDebugger,
+    PrintLn,
+    SlowDownYouMoveTooFast,
+)

 PrintLnHook = PrintLn  # for backwards compatibility -- this will be removed in 2.0

@@ -31,4 +36,5 @@
"NodeExecutionMethod",
"StaticValidator",
"TaskExecutionHook",
"FunctionInputOutputTypeChecker",
]
3 changes: 3 additions & 0 deletions hamilton/lifecycle/api.py
@@ -174,6 +174,7 @@ def run_before_node_execution(
         node_return_type: type,
         task_id: Optional[str],
         run_id: str,
+        node_input_types: Dict[str, Any],
         **future_kwargs: Any,
     ):
         """Hook that is executed prior to node execution.
@@ -184,6 +185,7 @@
         :param node_return_type: Return type of the node
         :param task_id: The ID of the task, none if not in a task-based environment
         :param run_id: Run ID (unique in process scope) of the current run. Use this to track state.
+        :param node_input_types: the input types to the node and what it is expecting
         :param future_kwargs: Additional keyword arguments -- this is kept for backwards compatibility
         """
         pass
@@ -206,6 +208,7 @@ def pre_node_execute(
             node_return_type=node_.type,
             task_id=task_id,
             run_id=run_id,
+            node_input_types={k: v[0] for k, v in node_.input_types.items()},
         )

Collaborator review comment (on the added node_input_types line): Add TODO for integrating the HamiltonNode class
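
For reference, node_.input_types in Hamilton maps each parameter name to a tuple whose first element is the declared type (hence the v[0]) and whose second marks whether the dependency is required or optional; the review note above asks for a TODO about moving this onto the HamiltonNode abstraction.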

     @abc.abstractmethod
123 changes: 122 additions & 1 deletion hamilton/lifecycle/default.py
@@ -4,9 +4,15 @@
 import pdb
 import pprint
 import random
+import sys
 import time
 from typing import Any, Callable, Dict, List, Optional, Union

+if sys.version_info >= (3, 9):
+    from typing import Literal
+else:
+    Literal = None
+
 from hamilton.lifecycle import NodeExecutionHook
 from hamilton.lifecycle.api import NodeExecutionMethod

Collaborator review comment (on the Literal import): Think literal is 3.8+?
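
The reviewer is right that Literal was added to typing in Python 3.8 (PEP 586), so the (3, 9) guard is stricter than necessary. A common alternative pattern (not what this commit does) would be:

try:
    from typing import Literal  # available since Python 3.8
except ImportError:
    from typing_extensions import Literal  # backport for 3.7 and earlier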

@@ -282,7 +288,7 @@ def run_after_node_execution(
         task_id: Optional[str],
         **future_kwargs: Any,
     ):
-        """Executes after a node, whether or not is was successful. Does nothing, just runs pdb.set_trace().
+        """Executes after a node, whether or not it was successful. Does nothing, just runs pdb.set_trace().

         :param node_name: Name of the node
         :param node_tags: Tags of the node
@@ -340,3 +346,118 @@ def run_before_node_execution(self, **future_kwargs: Any):
     def run_after_node_execution(self, **future_kwargs: Any):
         """Does nothing"""
         pass


+def check_instance(obj: Any, type_: Any) -> bool:
+    """This function checks if an object is an instance of a given type. It supports generic types as well.
+
+    :param obj: The object to check.
+    :param type_: The type to check against. This can be a generic type like List[int] or Dict[str, Any].
+    :return: True if the object is an instance of the type, False otherwise.
+    """
+    if type_ == Any:
+        return True
+    # Get the origin of the type (i.e., the base class for generic types)
+    origin = getattr(type_, "__origin__", None)
+
+    # If the type has an origin, it's a generic type
+    if origin is not None:
+        # If the type is a Union type
+        if origin is Union:
+            return any(check_instance(obj, t) for t in type_.__args__)
+        elif origin is Literal:
+            return obj in type_.__args__
+        # Check if the object is an instance of the origin of the type
+        elif not isinstance(obj, origin):
+            return False
+
+    # If the type has arguments (i.e., it's a parameterized generic type like List[int])
+    if hasattr(type_, "__args__"):
+        # Get the element type(s) of the generic type
+        element_type = type_.__args__
+
+        # If the object is a dictionary
+        if isinstance(obj, dict):
+            all_items_meet_condition = True
+
+            # Iterate over each key-value pair in the dictionary
+            for key, value in obj.items():
+                # Check if the key is an instance of the first element type and the value
+                # is an instance of the second element type
+                key_is_correct_type = check_instance(key, element_type[0])
+                value_is_correct_type = check_instance(value, element_type[1])
+
+                # If either the key or the value is not the correct type, set the flag to False and break
+                if not key_is_correct_type or not value_is_correct_type:
+                    all_items_meet_condition = False
+                    break
+
+            # Return the result
+            return all_items_meet_condition
+
+        # If the object is a list, set, or tuple
+        elif isinstance(obj, (list, set, tuple)):
+            element_type = element_type[0]
+            for i in obj:
+                if not check_instance(i, element_type):
+                    return False
+            return True
+
+    # If the type is not a generic type, just use isinstance
+    return isinstance(obj, type_)
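
A few illustrative calls, assuming check_instance behaves as written above (and Python >= 3.9 for the Literal case, given the import guard):

from typing import Dict, List, Literal, Optional

check_instance([1, 2, 3], List[int])          # True
check_instance([1, "a"], List[int])           # False
check_instance({"a": 1.0}, Dict[str, float])  # True
check_instance({"a": 1}, Dict[str, float])    # False: int is not float
check_instance(None, Optional[int])           # True: Optional[int] is Union[int, None]
check_instance("b", Literal["a", "b"])        # True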


+class FunctionInputOutputTypeChecker(NodeExecutionHook):
+    """This lifecycle hook checks the input and output types of a function.
+
+    It is a simple but very strict check of what was actually received against the declared type.
+    E.g. if you don't want to check the types of a dictionary, don't annotate it with a type.
+    """
+
+    def __init__(self, check_input: bool = True, check_output: bool = True):
+        """Constructor.
+
+        :param check_input: check inputs to all functions
+        :param check_output: check outputs of all functions
+        """
+        self.check_input = check_input
+        self.check_output = check_output
+
+    def run_before_node_execution(
+        self,
+        node_name: str,
+        node_tags: Dict[str, Any],
+        node_kwargs: Dict[str, Any],
+        node_return_type: type,
+        task_id: Optional[str],
+        run_id: str,
+        node_input_types: Dict[str, Any],
+        **future_kwargs: Any,
+    ):
+        """Checks that each input value matches the node's declared input type."""
+        if self.check_input:
+            for input_name, input_value in node_kwargs.items():
+                if not check_instance(input_value, node_input_types[input_name]):
+                    raise TypeError(
+                        f"Node {node_name} received an input of type {type(input_value)} for {input_name}, expected {node_input_types[input_name]}"
+                    )
+
+    def run_after_node_execution(
+        self,
+        node_name: str,
+        node_tags: Dict[str, Any],
+        node_kwargs: Dict[str, Any],
+        node_return_type: type,
+        result: Any,
+        error: Optional[Exception],
+        success: bool,
+        task_id: Optional[str],
+        run_id: str,
+        **future_kwargs: Any,
+    ):
+        """Checks that the result type matches the expected node return type."""
+        if self.check_output:
+            # Use check_instance rather than a bare isinstance so generics are handled.
+            if not check_instance(result, node_return_type):
+                raise TypeError(
+                    f"Node {node_name} returned a result of type {type(result)}, expected {node_return_type}"
+                )
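
Called directly (outside a driver), the hook raises on a mismatch. A hypothetical sketch:

from typing import Dict

from pandas import DataFrame

from hamilton.lifecycle import FunctionInputOutputTypeChecker

checker = FunctionInputOutputTypeChecker()
checker.run_before_node_execution(
    node_name="split_dataframe",
    node_tags={},
    node_kwargs={"input": {"not": "a dataframe"}},
    node_return_type=Dict[str, DataFrame],
    task_id=None,
    run_id="run-1",
    node_input_types={"input": DataFrame},
)
# -> TypeError: Node split_dataframe received an input of type <class 'dict'>
#    for input, expected <class 'pandas.core.frame.DataFrame'>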
6 changes: 2 additions & 4 deletions tests/integrations/pandera/test_pandera_data_quality.py
@@ -32,9 +32,7 @@ def test_basic_pandera_decorator_dataframe_fails():
     validator = pandera_validators.PanderaDataFrameValidator(schema=schema, importance="warn")
     validation_result = validator.validate(df)
     assert not validation_result.passes
-    assert (
-        "A total of 4 schema errors were found" in validation_result.message
-    )  # TODO -- ensure this will stay constant with the contract
+    assert len(validation_result.diagnostics["schema_errors"]) == 4


 def test_basic_pandera_decorator_dataframe_passes():
@@ -81,7 +79,7 @@ def test_basic_pandera_decorator_series_fails():
     validator = pandera_validators.PanderaSeriesSchemaValidator(schema=schema, importance="warn")
     validation_result = validator.validate(series)
     assert not validation_result.passes
-    assert "A total of 1 schema error" in validation_result.message
+    assert len(validation_result.diagnostics["schema_errors"]) == 1


 def test_basic_pandera_decorator_series_passes():
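
Asserting on the structured diagnostics rather than the rendered message resolves the old TODO: the tests no longer depend on pandera's exact error wording staying constant.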