Skip to content

Commit

Permalink
Merge pull request #22 from twosigma/serialize-exceptions
Browse files Browse the repository at this point in the history
Fixed NodeError and CallError not being serializable and added test c…
  • Loading branch information
daniel-shields authored Feb 9, 2024
2 parents a4021ee + 1d1ba0b commit 37a6c3e
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/uberjob/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def __init__(self, node: Node):
)
self.node = node

def __reduce__(self):
return NodeError, (self.node,)


class CallError(Exception):
"""
Expand All @@ -46,6 +49,9 @@ def __init__(self, call: Call):
)
self.call = call

def __reduce__(self):
return CallError, (self.call,)


class NotTransformedError(Exception):
"""An expected transformation was not applied."""
Expand Down
8 changes: 8 additions & 0 deletions src/uberjob/_util/traceback.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,18 @@ def __repr__(self):
)


TruncatedStackFrame = None


class TruncatedStackFrameType:
def __repr__(self):
return "TruncatedStackFrame"

def __new__(cls, *args, **kwargs):
if TruncatedStackFrame is not None:
return TruncatedStackFrame
return super().__new__(cls, *args, **kwargs)


TruncatedStackFrame = TruncatedStackFrameType()

Expand Down
32 changes: 32 additions & 0 deletions tests/test_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
import operator
import os
import pathlib
import pickle
import re
import tempfile

import networkx as nx

import uberjob
from uberjob._errors import NodeError
from uberjob._util.traceback import get_stack_frame
from uberjob.graph import Call
from uberjob.progress import console_progress, default_progress, html_progress
Expand Down Expand Up @@ -373,3 +375,33 @@ def wrapper(*args, **kwargs):
expected_exception_chain_traceback_summary=[["wrapper"]]
):
uberjob.run(plan, output=call, retry=bad_retry2)

def test_serialize_call_error(self):
plan = uberjob.Plan()
call = plan.call(operator.truediv, 1, 0)
exception = None
try:
uberjob.run(plan, output=call)
except uberjob.CallError as e:
exception = e
self.assertIsNotNone(exception)
pickled_exception = pickle.dumps(exception)
unpickled_exception = pickle.loads(pickled_exception)
self.assertIsInstance(unpickled_exception, uberjob.CallError)
self.assertIsInstance(unpickled_exception.call, Call)
self.assertIs(unpickled_exception.call.fn, operator.truediv)

def test_serialize_node_error(self):
plan = uberjob.Plan()
call = plan.call(pow, 2, 2)
exception = None
try:
raise NodeError(call)
except NodeError as e:
exception = e
self.assertIsNotNone(exception)
pickled_exception = pickle.dumps(exception)
unpickled_exception = pickle.loads(pickled_exception)
self.assertIsInstance(unpickled_exception, NodeError)
self.assertIsInstance(unpickled_exception.node, Call)
self.assertIs(unpickled_exception.node.fn, pow)

0 comments on commit 37a6c3e

Please sign in to comment.