Skip to content

Commit

Permalink
Adds graceful fail-over
Browse files Browse the repository at this point in the history
This is still a bit of a WIP, but the API will stay backwards compatible
so I'm OK putting it in the stdlib.

This efectively cascades through null (customizable sentinel value)
results, not running a node if any of its dependencies are null.
  • Loading branch information
elijahbenizzy committed Jun 4, 2024
1 parent 7700507 commit e7bda85
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 1 deletion.
1 change: 1 addition & 0 deletions hamilton/lifecycle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .base import LifecycleAdapter # noqa: F401
from .default import ( # noqa: F401
FunctionInputOutputTypeChecker,
GracefulErrorAdapter,
PDBDebugger,
PrintLn,
SlowDownYouMoveTooFast,
Expand Down
38 changes: 37 additions & 1 deletion hamilton/lifecycle/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import random
import shelve
import time
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Type, Union

from hamilton import graph_types, htypes
from hamilton.graph_types import HamiltonGraph
Expand Down Expand Up @@ -506,3 +506,39 @@ def run_after_node_execution(
raise TypeError(
f"Node {node_name} returned a result of type {type(result)}, expected {node_return_type}"
)


SENTINEL_DEFAULT = None # sentinel value -- lazy for now


class GracefulErrorAdapter(NodeExecutionMethod):
"""Gracefully handles errors in a graph's execution. This allows you to proceed despite failure,
dynamically pruning branches. While it still runs every node, it replaces them with no-ops if any upstream
required dependencies fail (including optional dependencies).
"""

def __init__(self, error_to_catch: Type[Exception], sentinel_value: Any = SENTINEL_DEFAULT):
"""Initializes the adapter. Allows you to customize the error to catch (which exception
your graph will throw to indicate failure), as well as the sentinel value to use in place of
a node's result if it fails (this defaults to ``None``).
Note that this is currently only compatible with the dict-based result builder (use at your
own risk with pandas series, etc...).
:param error_to_catch: The error to catch
:param sentinel_value: The sentinel value to use in place of a node's result if it fails
"""
self.error_to_catch = error_to_catch
self.sentinel_value = sentinel_value

def run_to_execute_node(
self, *, node_callable: Any, node_kwargs: Dict[str, Any], **future_kwargs: Any
) -> Any:
"""Executes a node. If the node fails, returns the sentinel value."""
for key, value in node_kwargs.items():
if value == self.sentinel_value: # == versus is
return self.sentinel_value # cascade it through
try:
return node_callable(**node_kwargs)
except self.error_to_catch:
return self.sentinel_value
69 changes: 69 additions & 0 deletions scratch/graceful_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from hamilton import driver
from hamilton.lifecycle import default


class DoNotProceed(Exception):
pass


def working_res_1() -> int:
return 1


def working_res_2() -> int:
return 2


def do_not_proceed_1(working_res_1: int, working_res_2: int) -> int:
raise DoNotProceed()


def proceed_1(working_res_1: int, working_res_2: int) -> int:
return working_res_1 + working_res_2


def proceed_2(working_res_1: int, working_res_2: int) -> int:
return working_res_1 * working_res_2


def short_circuited_1(working_res_1: int, working_res_2: int, do_not_proceed_1: int) -> int:
return 1 # this should not be reached


def proceed_3(working_res_1: int, working_res_2: int) -> int:
return working_res_1 - working_res_2


def do_not_proceed_2(proceed_1: int, proceed_2: int, proceed_3: int) -> int:
raise DoNotProceed()


def short_circuited_2(proceed_1: int, proceed_2: int, proceed_3: int, do_not_proceed_2: int) -> int:
return 1 # this should not be reached


def short_circuited_3(short_circuited_1: int) -> int:
return 1 # this should not be reached


def proceed_4(proceed_1: int, proceed_2: int, proceed_3: int) -> int:
return proceed_1 + proceed_2 + proceed_3


if __name__ == "__main__":
import __main__

dr = (
driver.Builder()
.with_modules(__main__)
.with_adapters(
default.GracefulErrorAdapter(error_to_catch=DoNotProceed, sentinel_value=None)
)
.build()
)
dr.display_all_functions()
vars = dr.list_available_variables()
res = dr.execute(vars)
import pprint

pprint.pprint(res)

0 comments on commit e7bda85

Please sign in to comment.