diff --git a/src/uberjob/_errors.py b/src/uberjob/_errors.py index 56bf78a..3ba2409 100644 --- a/src/uberjob/_errors.py +++ b/src/uberjob/_errors.py @@ -15,7 +15,17 @@ # from uberjob._util import fully_qualified_name from uberjob._util.traceback import render_symbolic_traceback -from uberjob.graph import Call +from uberjob.graph import Call, Node + + +class NodeError(Exception): + """An exception was raised during execution of a node.""" + + def __init__(self, node: Node): + super().__init__( + f"An exception was raised during execution of the following node: {node!r}." + ) + self.node = node class CallError(Exception): diff --git a/src/uberjob/_execution/run_function_on_graph.py b/src/uberjob/_execution/run_function_on_graph.py index 8051773..c53863c 100644 --- a/src/uberjob/_execution/run_function_on_graph.py +++ b/src/uberjob/_execution/run_function_on_graph.py @@ -19,21 +19,12 @@ from contextlib import contextmanager from typing import Dict, List, NamedTuple, Set +from uberjob._errors import NodeError from uberjob._execution.scheduler import create_queue from uberjob._util.networkx_util import assert_acyclic, predecessor_count from uberjob.graph import Node -class NodeError(Exception): - """An exception was raised during execution of a node.""" - - def __init__(self, node): - super().__init__( - f"An exception was raised during execution of the following node: {node!r}." - ) - self.node = node - - def coerce_worker_count(worker_count): if worker_count is None: # Matches concurrent.futures.ThreadPoolExecutor in Python 3.8+. @@ -111,6 +102,14 @@ def prepare_nodes(graph) -> PreparedNodes: ) +def coerce_node_error(node: Node, exception: Exception) -> NodeError: + if isinstance(exception, NodeError): + return exception + node_error = NodeError(node) + node_error.__cause__ = exception + return node_error + + def run_function_on_graph( graph, fn, *, worker_count=None, max_errors=0, scheduler=None ): @@ -142,8 +141,7 @@ def process_node(node): with failure_lock: error_count += 1 if not first_node_error: - first_node_error = NodeError(node) - first_node_error.__cause__ = exception + first_node_error = coerce_node_error(node, exception) if max_errors is not None and error_count > max_errors: stop = True else: diff --git a/src/uberjob/_execution/run_physical.py b/src/uberjob/_execution/run_physical.py index f0aeeb3..3de78f9 100644 --- a/src/uberjob/_execution/run_physical.py +++ b/src/uberjob/_execution/run_physical.py @@ -16,7 +16,7 @@ """Functionality for executing a physical plan""" from typing import Any, Callable, Dict, NamedTuple, Optional -from uberjob._errors import create_chained_call_error +from uberjob._errors import NodeError, create_chained_call_error from uberjob._execution.run_function_on_graph import run_function_on_graph from uberjob._graph import get_full_call_scope from uberjob._plan import Plan @@ -38,10 +38,10 @@ def __init__(self, args, kwargs, result): self.kwargs = kwargs self.result = result - def run(self, fn): + def run(self, fn, retry): args = [arg.value for arg in self.args] kwargs = {name: arg.value for name, arg in self.kwargs.items()} - self.result.value = fn(*args, **kwargs) + self.result.value = retry(fn)(*args, **kwargs) def _create_bound_call( @@ -98,14 +98,16 @@ def process(node): progress_observer.increment_running(section="run", scope=scope) bound_call = bound_call_lookup[node] try: - bound_call.value.run(retry(node.fn)) + bound_call.value.run(node.fn, retry) except Exception as exception: + # Drop internal frames + exception.__traceback__ = exception.__traceback__.tb_next.tb_next progress_observer.increment_failed( section="run", scope=scope, exception=create_chained_call_error(node, exception), ) - raise + raise NodeError(node) from exception finally: bound_call.value = None progress_observer.increment_completed(section="run", scope=scope) diff --git a/src/uberjob/_transformations/caching.py b/src/uberjob/_transformations/caching.py index 1b3e28a..348eca7 100644 --- a/src/uberjob/_transformations/caching.py +++ b/src/uberjob/_transformations/caching.py @@ -17,7 +17,7 @@ import datetime as dt from typing import Optional, Set, Tuple -from uberjob._errors import create_chained_call_error +from uberjob._errors import NodeError, create_chained_call_error from uberjob._execution.run_function_on_graph import run_function_on_graph from uberjob._graph import get_full_call_scope from uberjob._plan import Plan @@ -110,12 +110,16 @@ def process_with_callbacks(node): try: process(node) except Exception as exception: + # Drop internal frames + exception.__traceback__ = ( + exception.__traceback__.tb_next.tb_next.tb_next + ) progress_observer.increment_failed( section="stale", scope=scope, exception=create_chained_call_error(node, exception), ) - raise + raise NodeError(node) from exception progress_observer.increment_completed(section="stale", scope=scope) else: process(node) diff --git a/tests/test_plan.py b/tests/test_plan.py index 4a7e839..459f8c1 100644 --- a/tests/test_plan.py +++ b/tests/test_plan.py @@ -334,3 +334,42 @@ def test_retry_validation(self): uberjob.run(plan, retry=0) with self.assertRaises(ValueError): uberjob.run(plan, retry=-1) + + def test_traceback_manipulation(self): + def x(): + raise IOError("buzz") + + def y(): + try: + return x() + except Exception as e: + raise ValueError("fizz") from e + + def z(): + return y() + + plan = uberjob.Plan() + call = plan.call(z) + with self.assert_call_exception( + expected_exception_chain_traceback_summary=[["z", "y"], ["y", "x"]] + ): + uberjob.run(plan, output=call) + + def bad_retry1(f): + raise Exception() + + with self.assert_call_exception( + expected_exception_chain_traceback_summary=[["bad_retry1"]] + ): + uberjob.run(plan, output=call, retry=bad_retry1) + + def bad_retry2(f): + def wrapper(*args, **kwargs): + raise ValueError() + + return wrapper + + with self.assert_call_exception( + expected_exception_chain_traceback_summary=[["wrapper"]] + ): + uberjob.run(plan, output=call, retry=bad_retry2) diff --git a/tests/test_registry.py b/tests/test_registry.py index c156c04..538658a 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -343,3 +343,62 @@ def test_registry_copy(self): self.assertNotIn(y, registry) self.assertIn(x, registry_copy) self.assertEqual(registry[x], registry_copy[x]) + + def test_traceback_manipulation(self): + def buzz(): + raise ValueError() + + def fizz(): + try: + return buzz() + except ValueError as e: + raise Exception() from e + + class BadValueStore(uberjob.ValueStore): + def read(self): + raise NotImplementedError() + + def write(self): + raise NotImplementedError() + + def get_modified_time(self): + return fizz() + + class BadValueStore2(uberjob.ValueStore): + def read(self): + raise NotImplementedError() + + def write(self): + raise NotImplementedError() + + def get_modified_time(self): + return 7 + + plan = uberjob.Plan() + registry = uberjob.Registry() + node = registry.source(plan, BadValueStore()) + + with self.assert_call_exception( + expected_exception_chain_traceback_summary=[ + ["get_modified_time", "fizz"], + ["fizz", "buzz"], + ] + ): + uberjob.run(plan, registry=registry, output=node) + + def bad_retry(f): + raise Exception() + + with self.assert_call_exception( + expected_exception_chain_traceback_summary=[["bad_retry"]] + ): + uberjob.run(plan, registry=registry, output=node, retry=bad_retry) + + plan = uberjob.Plan() + registry = uberjob.Registry() + node = registry.source(plan, BadValueStore2()) + + with self.assert_call_exception( + expected_exception_chain_traceback_summary=[["_to_naive_utc_time"]] + ): + uberjob.run(plan, registry=registry, output=node) diff --git a/tests/util.py b/tests/util.py index 34c9432..f9834f0 100644 --- a/tests/util.py +++ b/tests/util.py @@ -21,23 +21,48 @@ from uberjob._util.traceback import StackFrame +def _traceback_summary(traceback) -> list[str]: + result = [] + while traceback is not None: + result.append(traceback.tb_frame.f_code.co_name) + traceback = traceback.tb_next + return result + + +def _exception_chain_traceback_summary(exception) -> list[list[str]]: + result = [] + while exception is not None: + result.append(_traceback_summary(exception.__traceback__)) + exception = exception.__cause__ + return result + + @contextmanager def assert_call_exception( - test_case, expected_exception=None, expected_regex=None, expected_stack_frame=None + test_case, + expected_exception=None, + expected_regex=None, + expected_stack_frame=None, + expected_exception_chain_traceback_summary=None, ): try: yield except CallError as e: test_case.assertRegex(str(e), "An exception was raised in a symbolic call.*") - if expected_exception: + if expected_exception is not None: test_case.assertIsInstance(e.__cause__, expected_exception) - if expected_regex: + if expected_regex is not None: test_case.assertRegex(str(e.__cause__), expected_regex) - if expected_stack_frame: + if expected_stack_frame is not None: stack_frame = e.call.stack_frame test_case.assertEqual(stack_frame.name, expected_stack_frame.name) test_case.assertEqual(stack_frame.path, expected_stack_frame.path) test_case.assertEqual(stack_frame.line, expected_stack_frame.line) + if expected_exception_chain_traceback_summary is not None: + test_case.assertEqual( + _exception_chain_traceback_summary(e.__cause__), + expected_exception_chain_traceback_summary, + ) else: test_case.fail("Call exception not raised.")