Skip to content

Commit

Permalink
simplify tracebacks by dropping internal stack frames
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-shields committed Aug 18, 2023
1 parent cb227c3 commit a1a0198
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 24 deletions.
12 changes: 11 additions & 1 deletion src/uberjob/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,17 @@
#
from uberjob._util import fully_qualified_name
from uberjob._util.traceback import render_symbolic_traceback
from uberjob.graph import Call
from uberjob.graph import Call, Node


class NodeError(Exception):
    """Signal that executing a graph node raised an exception.

    The node that was being executed is kept on the ``node`` attribute so
    that handlers can tell which node failed.
    """

    def __init__(self, node: Node):
        """Build the message from ``repr(node)`` and remember ``node``."""
        super().__init__(
            f"An exception was raised during execution of the following node: {node!r}."
        )
        self.node = node


class CallError(Exception):
Expand Down
22 changes: 10 additions & 12 deletions src/uberjob/_execution/run_function_on_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,12 @@
from contextlib import contextmanager
from typing import Dict, List, NamedTuple, Set

from uberjob._errors import NodeError
from uberjob._execution.scheduler import create_queue
from uberjob._util.networkx_util import assert_acyclic, predecessor_count
from uberjob.graph import Node


class NodeError(Exception):
    """Signal that executing a graph node raised an exception.

    The failing node is exposed as the ``node`` attribute.
    """

    def __init__(self, node):
        """Build the message from ``repr(node)`` and remember ``node``."""
        super().__init__(
            f"An exception was raised during execution of the following node: {node!r}."
        )
        self.node = node


def coerce_worker_count(worker_count):
if worker_count is None:
# Matches concurrent.futures.ThreadPoolExecutor in Python 3.8+.
Expand Down Expand Up @@ -111,6 +102,14 @@ def prepare_nodes(graph) -> PreparedNodes:
)


def coerce_node_error(node: Node, exception: Exception) -> NodeError:
    """Return ``exception`` as a ``NodeError``, wrapping it only when needed.

    An exception that already is a ``NodeError`` is returned unchanged; it is
    not re-wrapped for ``node``. Any other exception becomes the ``__cause__``
    of a new ``NodeError`` created for ``node``.
    """
    if isinstance(exception, NodeError):
        return exception
    node_error = NodeError(node)
    # Chain manually: the error is only built here, not raised, so
    # ``raise ... from`` is not available.
    node_error.__cause__ = exception
    return node_error


def run_function_on_graph(
graph, fn, *, worker_count=None, max_errors=0, scheduler=None
):
Expand Down Expand Up @@ -142,8 +141,7 @@ def process_node(node):
with failure_lock:
error_count += 1
if not first_node_error:
first_node_error = NodeError(node)
first_node_error.__cause__ = exception
first_node_error = coerce_node_error(node, exception)
if max_errors is not None and error_count > max_errors:
stop = True
else:
Expand Down
12 changes: 7 additions & 5 deletions src/uberjob/_execution/run_physical.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""Functionality for executing a physical plan"""
from typing import Any, Callable, Dict, NamedTuple, Optional

from uberjob._errors import create_chained_call_error
from uberjob._errors import NodeError, create_chained_call_error
from uberjob._execution.run_function_on_graph import run_function_on_graph
from uberjob._graph import get_full_call_scope
from uberjob._plan import Plan
Expand All @@ -38,10 +38,10 @@ def __init__(self, args, kwargs, result):
self.kwargs = kwargs
self.result = result

def run(self, fn, retry):
    """Call ``fn`` through ``retry`` and store the outcome on ``self.result``.

    Positional and keyword arguments are taken from the ``value`` attribute
    of the entries in ``self.args`` and ``self.kwargs``. ``retry`` is applied
    to ``fn`` here, and the wrapped callable is invoked exactly once; its
    return value is assigned to ``self.result.value``.
    """
    args = [arg.value for arg in self.args]
    kwargs = {name: arg.value for name, arg in self.kwargs.items()}
    self.result.value = retry(fn)(*args, **kwargs)


def _create_bound_call(
Expand Down Expand Up @@ -98,14 +98,16 @@ def process(node):
progress_observer.increment_running(section="run", scope=scope)
bound_call = bound_call_lookup[node]
try:
bound_call.value.run(retry(node.fn))
bound_call.value.run(node.fn, retry)
except Exception as exception:
# Drop internal frames
exception.__traceback__ = exception.__traceback__.tb_next.tb_next
progress_observer.increment_failed(
section="run",
scope=scope,
exception=create_chained_call_error(node, exception),
)
raise
raise NodeError(node) from exception
finally:
bound_call.value = None
progress_observer.increment_completed(section="run", scope=scope)
Expand Down
8 changes: 6 additions & 2 deletions src/uberjob/_transformations/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import datetime as dt
from typing import Optional, Set, Tuple

from uberjob._errors import create_chained_call_error
from uberjob._errors import NodeError, create_chained_call_error
from uberjob._execution.run_function_on_graph import run_function_on_graph
from uberjob._graph import get_full_call_scope
from uberjob._plan import Plan
Expand Down Expand Up @@ -110,12 +110,16 @@ def process_with_callbacks(node):
try:
process(node)
except Exception as exception:
# Drop internal frames
exception.__traceback__ = (
exception.__traceback__.tb_next.tb_next.tb_next
)
progress_observer.increment_failed(
section="stale",
scope=scope,
exception=create_chained_call_error(node, exception),
)
raise
raise NodeError(node) from exception
progress_observer.increment_completed(section="stale", scope=scope)
else:
process(node)
Expand Down
39 changes: 39 additions & 0 deletions tests/test_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,42 @@ def test_retry_validation(self):
uberjob.run(plan, retry=0)
with self.assertRaises(ValueError):
uberjob.run(plan, retry=-1)

def test_traceback_manipulation(self):
    """Check which frames remain in the tracebacks of a failed call.

    Each expected summary holds one list per exception in the ``__cause__``
    chain; every list contains the function names found in that exception's
    traceback. The local function names below are therefore part of the
    assertions and must not be renamed.
    """

    def x():
        raise IOError("buzz")

    def y():
        try:
            return x()
        except Exception as e:
            raise ValueError("fizz") from e

    def z():
        return y()

    plan = uberjob.Plan()
    call = plan.call(z)
    # The ValueError traceback covers z -> y; its cause, the IOError,
    # covers y -> x.
    with self.assert_call_exception(
        expected_exception_chain_traceback_summary=[["z", "y"], ["y", "x"]]
    ):
        uberjob.run(plan, output=call)

    # A retry callable that fails while wrapping the function.
    def bad_retry1(f):
        raise Exception()

    with self.assert_call_exception(
        expected_exception_chain_traceback_summary=[["bad_retry1"]]
    ):
        uberjob.run(plan, output=call, retry=bad_retry1)

    # A retry callable whose returned wrapper fails when it is invoked.
    def bad_retry2(f):
        def wrapper(*args, **kwargs):
            raise ValueError()

        return wrapper

    with self.assert_call_exception(
        expected_exception_chain_traceback_summary=[["wrapper"]]
    ):
        uberjob.run(plan, output=call, retry=bad_retry2)
59 changes: 59 additions & 0 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,62 @@ def test_registry_copy(self):
self.assertNotIn(y, registry)
self.assertIn(x, registry_copy)
self.assertEqual(registry[x], registry_copy[x])

def test_traceback_manipulation(self):
    """Check which frames remain in tracebacks of failing value stores.

    Each expected summary holds one list per exception in the ``__cause__``
    chain; every list contains the function names found in that exception's
    traceback. The function and method names below are therefore part of the
    assertions and must not be renamed.
    """

    def buzz():
        raise ValueError()

    def fizz():
        try:
            return buzz()
        except ValueError as e:
            raise Exception() from e

    # Value store whose modified-time lookup raises a chained exception.
    class BadValueStore(uberjob.ValueStore):
        def read(self):
            raise NotImplementedError()

        def write(self):
            raise NotImplementedError()

        def get_modified_time(self):
            return fizz()

    # Value store whose modified-time lookup returns 7 instead of raising.
    class BadValueStore2(uberjob.ValueStore):
        def read(self):
            raise NotImplementedError()

        def write(self):
            raise NotImplementedError()

        def get_modified_time(self):
            return 7

    plan = uberjob.Plan()
    registry = uberjob.Registry()
    node = registry.source(plan, BadValueStore())

    # The outer Exception covers get_modified_time -> fizz; its cause, the
    # ValueError, covers fizz -> buzz.
    with self.assert_call_exception(
        expected_exception_chain_traceback_summary=[
            ["get_modified_time", "fizz"],
            ["fizz", "buzz"],
        ]
    ):
        uberjob.run(plan, registry=registry, output=node)

    # A retry callable that fails while wrapping the function.
    def bad_retry(f):
        raise Exception()

    with self.assert_call_exception(
        expected_exception_chain_traceback_summary=[["bad_retry"]]
    ):
        uberjob.run(plan, registry=registry, output=node, retry=bad_retry)

    plan = uberjob.Plan()
    registry = uberjob.Registry()
    node = registry.source(plan, BadValueStore2())

    # Here the only remaining frame is expected to be ``_to_naive_utc_time``.
    with self.assert_call_exception(
        expected_exception_chain_traceback_summary=[["_to_naive_utc_time"]]
    ):
        uberjob.run(plan, registry=registry, output=node)
33 changes: 29 additions & 4 deletions tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,48 @@
from uberjob._util.traceback import StackFrame


def _traceback_summary(traceback) -> list[str]:
result = []
while traceback is not None:
result.append(traceback.tb_frame.f_code.co_name)
traceback = traceback.tb_next
return result


def _exception_chain_traceback_summary(exception) -> list[list[str]]:
    """Summarize the tracebacks of an exception and its explicit causes.

    Starting at ``exception`` and following ``__cause__`` until it is
    ``None``, collect one ``_traceback_summary`` per exception. Passing
    ``None`` yields an empty list.
    """
    result = []
    while exception is not None:
        result.append(_traceback_summary(exception.__traceback__))
        exception = exception.__cause__
    return result


@contextmanager
def assert_call_exception(
    test_case,
    expected_exception=None,
    expected_regex=None,
    expected_stack_frame=None,
    expected_exception_chain_traceback_summary=None,
):
    """Assert that the wrapped block raises a ``CallError``.

    The error message is always checked. Every ``expected_*`` argument that
    is not ``None`` adds a further check:

    - ``expected_exception``: type of the error's ``__cause__``.
    - ``expected_regex``: pattern matched against ``str`` of the cause.
    - ``expected_stack_frame``: ``name``, ``path`` and ``line`` compared with
      ``e.call.stack_frame``.
    - ``expected_exception_chain_traceback_summary``: compared with the
      traceback summary of the cause and its own ``__cause__`` chain.

    If the block finishes without raising, ``test_case.fail`` is called.
    Exceptions other than ``CallError`` are not caught.
    """
    try:
        yield
    except CallError as e:
        test_case.assertRegex(str(e), "An exception was raised in a symbolic call.*")
        # ``is not None`` so that falsy expectations (e.g. an empty list)
        # are still checked.
        if expected_exception is not None:
            test_case.assertIsInstance(e.__cause__, expected_exception)
        if expected_regex is not None:
            test_case.assertRegex(str(e.__cause__), expected_regex)
        if expected_stack_frame is not None:
            stack_frame = e.call.stack_frame
            test_case.assertEqual(stack_frame.name, expected_stack_frame.name)
            test_case.assertEqual(stack_frame.path, expected_stack_frame.path)
            test_case.assertEqual(stack_frame.line, expected_stack_frame.line)
        if expected_exception_chain_traceback_summary is not None:
            test_case.assertEqual(
                _exception_chain_traceback_summary(e.__cause__),
                expected_exception_chain_traceback_summary,
            )
    else:
        test_case.fail("Call exception not raised.")

Expand Down

0 comments on commit a1a0198

Please sign in to comment.