-
Notifications
You must be signed in to change notification settings - Fork 133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds FunctionInputOutputTypeChecker #757
Merged
Merged
Changes from 3 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -174,6 +174,7 @@ def run_before_node_execution( | |
node_return_type: type, | ||
task_id: Optional[str], | ||
run_id: str, | ||
node_input_types: Dict[str, Any], | ||
**future_kwargs: Any, | ||
): | ||
"""Hook that is executed prior to node execution. | ||
|
@@ -184,6 +185,7 @@ def run_before_node_execution( | |
:param node_return_type: Return type of the node | ||
:param task_id: The ID of the task, none if not in a task-based environment | ||
:param run_id: Run ID (unique in process scope) of the current run. Use this to track state. | ||
:param node_input_types: the input types to the node and what it is expecting | ||
:param future_kwargs: Additional keyword arguments -- this is kept for backwards compatibility | ||
""" | ||
pass | ||
|
@@ -206,6 +208,7 @@ def pre_node_execute( | |
node_return_type=node_.type, | ||
task_id=task_id, | ||
run_id=run_id, | ||
node_input_types={k: v[0] for k, v in node_.input_types.items()}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add TODO for integrating the |
||
) | ||
|
||
@abc.abstractmethod | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,9 +4,15 @@ | |
import pdb | ||
import pprint | ||
import random | ||
import sys | ||
import time | ||
from typing import Any, Callable, Dict, List, Optional, Union | ||
|
||
if sys.version_info >= (3, 9): | ||
from typing import Literal | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Think literal is |
||
else: | ||
Literal = None | ||
|
||
from hamilton.lifecycle import NodeExecutionHook | ||
from hamilton.lifecycle.api import NodeExecutionMethod | ||
|
||
|
@@ -282,7 +288,7 @@ def run_after_node_execution( | |
task_id: Optional[str], | ||
**future_kwargs: Any, | ||
): | ||
"""Executes after a node, whether or not is was successful. Does nothing, just runs pdb.set_trace(). | ||
"""Executes after a node, whether or not it was successful. Does nothing, just runs pdb.set_trace(). | ||
|
||
:param node_name: Name of the node | ||
:param node_tags: Tags of the node | ||
|
@@ -340,3 +346,118 @@ def run_before_node_execution(self, **future_kwargs: Any): | |
def run_after_node_execution(self, **future_kwargs: Any): | ||
"""Does nothing""" | ||
pass | ||
|
||
|
||
def check_instance(obj: Any, type_: Any) -> bool: | ||
skrawcz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""This function checks if an object is an instance of a given type. It supports generic types as well. | ||
|
||
:param obj: The object to check. | ||
:param type_: The type to check against. This can be a generic type like List[int] or Dict[str, Any]. | ||
:return: True if the object is an instance of the type, False otherwise. | ||
""" | ||
if type_ == Any: | ||
return True | ||
# Get the origin of the type (i.e., the base class for generic types) | ||
origin = getattr(type_, "__origin__", None) | ||
|
||
# If the type has an origin, it's a generic type | ||
if origin is not None: | ||
# If the type is a Union type | ||
if origin is Union: | ||
return any(check_instance(obj, t) for t in type_.__args__) | ||
elif origin is Literal: | ||
return obj in type_.__args__ | ||
# Check if the object is an instance of the origin of the type | ||
elif not isinstance(obj, origin): | ||
return False | ||
|
||
# If the type has arguments (i.e., it's a parameterized generic type like List[int]) | ||
if hasattr(type_, "__args__"): | ||
# Get the element type(s) of the generic type | ||
element_type = type_.__args__ | ||
|
||
# If the object is a dictionary | ||
if isinstance(obj, dict): | ||
all_items_meet_condition = True | ||
|
||
# Iterate over each key-value pair in the dictionary | ||
for key, value in obj.items(): | ||
# Check if the key is an instance of the first element type and the value is an instance of the second element type | ||
key_is_correct_type = check_instance(key, element_type[0]) | ||
value_is_correct_type = check_instance(value, element_type[1]) | ||
|
||
# If either the key or the value is not the correct type, set the flag to False and break the loop | ||
if not key_is_correct_type or not value_is_correct_type: | ||
all_items_meet_condition = False | ||
break | ||
|
||
# Return the result | ||
return all_items_meet_condition | ||
|
||
# If the object is a list, set, or tuple | ||
elif isinstance(obj, (list, set, tuple)): | ||
element_type = element_type[0] | ||
for i in obj: | ||
if not check_instance(i, element_type): | ||
return False | ||
return True | ||
|
||
# If the type is not a generic type, just use isinstance | ||
return isinstance(obj, type_) | ||
|
||
|
||
class FunctionInputOutputTypeChecker(NodeExecutionHook): | ||
"""This lifecycle hook checks the input and output types of a function. | ||
|
||
It is a simple, but very strict type check against the declared type with what was actually received. | ||
E.g. if you don't want to check the types of a dictionary, don't annotate it with a type. | ||
""" | ||
|
||
def __init__(self, check_input: bool = True, check_output: bool = True): | ||
"""Constructor. | ||
|
||
:param check_input: check inputs to all functions | ||
:param check_output: check outputs to all functions | ||
""" | ||
self.check_input = check_input | ||
self.check_output = check_output | ||
|
||
def run_before_node_execution( | ||
self, | ||
node_name: str, | ||
node_tags: Dict[str, Any], | ||
node_kwargs: Dict[str, Any], | ||
node_return_type: type, | ||
task_id: Optional[str], | ||
run_id: str, | ||
node_input_types: Dict[str, Any], | ||
**future_kwargs: Any, | ||
): | ||
"""Checks that the result type matches the expected node return type.""" | ||
if self.check_input: | ||
for input_name, input_value in node_kwargs.items(): | ||
if not check_instance(input_value, node_input_types[input_name]): | ||
raise TypeError( | ||
f"Node {node_name} received an input of type {type(input_value)} for {input_name}, expected {node_input_types[input_name]}" | ||
) | ||
|
||
def run_after_node_execution( | ||
self, | ||
node_name: str, | ||
node_tags: Dict[str, Any], | ||
node_kwargs: Dict[str, Any], | ||
node_return_type: type, | ||
result: Any, | ||
error: Optional[Exception], | ||
success: bool, | ||
task_id: Optional[str], | ||
run_id: str, | ||
**future_kwargs: Any, | ||
): | ||
"""Checks that the result type matches the expected node return type.""" | ||
if self.check_output: | ||
# Replace the isinstance check in your code with check_instance | ||
if not check_instance(result, node_return_type): | ||
raise TypeError( | ||
f"Node {node_name} returned a result of type {type(result)}, expected {node_return_type}" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm we should really have a context manager that does this... But this works for now.