Skip to content

Commit

Permalink
Mark run as failed in case of error using ctx manager (#1755)
Browse files Browse the repository at this point in the history
  • Loading branch information
AleksanderWWW authored May 14, 2024
1 parent d4d7fdd commit 159f577
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 30 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
## [UNRELEASED] neptune 1.10.4

### Fixes
- Fixed run not failing in case of an exception if context manager was used ([#1755](https://github.com/neptune-ai/neptune-client/pull/1755))


## neptune 1.10.3

### Fixes
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,6 @@ module = [
"neptune.internal.utils.s3",
"neptune.internal.utils.source_code",
"neptune.internal.utils.traceback_job",
"neptune.internal.utils.uncaught_exception_handler",
"neptune.internal.websockets.websocket_signals_background_job",
"neptune.internal.websockets.websockets_factory",
"neptune.legacy",
Expand Down
71 changes: 42 additions & 29 deletions src/neptune/internal/utils/uncaught_exception_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,60 +20,73 @@
import traceback
import uuid
from platform import node as get_hostname
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Type,
)

from neptune.internal.utils.logger import get_logger

if TYPE_CHECKING:
pass

_logger = get_logger()

SYS_UNCAUGHT_EXCEPTION_HANDLER_TYPE = Callable[[Type[BaseException], BaseException, Optional[TracebackType]], Any]


class UncaughtExceptionHandler:
def __init__(self):
self._previous_uncaught_exception_handler = None
self._handlers = dict()
def __init__(self) -> None:
self._previous_uncaught_exception_handler: Optional[SYS_UNCAUGHT_EXCEPTION_HANDLER_TYPE] = None
self._handlers: Dict[uuid.UUID, Callable[[List[str]], None]] = dict()
self._lock = threading.Lock()

def activate(self):
def trigger(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
header_lines = [
f"An uncaught exception occurred while run was active on worker {get_hostname()}.",
"Marking run as failed",
"Traceback:",
]

traceback_lines = header_lines + traceback.format_tb(exc_tb) + str(exc_val).split("\n")
for _, handler in self._handlers.items():
handler(traceback_lines)

def activate(self) -> None:
with self._lock:
this = self

def exception_handler(exc_type, exc_val, exc_tb):
header_lines = [
f"An uncaught exception occurred while run was active on worker {get_hostname()}.",
"Marking run as failed",
"Traceback:",
]

traceback_lines = header_lines + traceback.format_tb(exc_tb) + str(exc_val).split("\n")
for _, handler in self._handlers.items():
handler(traceback_lines)

this._previous_uncaught_exception_handler(exc_type, exc_val, exc_tb)
if self._previous_uncaught_exception_handler is not None:
return
self._previous_uncaught_exception_handler = sys.excepthook
sys.excepthook = self.exception_handler

if self._previous_uncaught_exception_handler is None:
self._previous_uncaught_exception_handler = sys.excepthook
sys.excepthook = exception_handler

def deactivate(self):
def deactivate(self) -> None:
with self._lock:
if self._previous_uncaught_exception_handler is None:
return
sys.excepthook = self._previous_uncaught_exception_handler
self._previous_uncaught_exception_handler = None

def register(self, uid: uuid.UUID, handler: Callable[[List[str]], None]):
def register(self, uid: uuid.UUID, handler: Callable[[List[str]], None]) -> None:
with self._lock:
self._handlers[uid] = handler

def unregister(self, uid: uuid.UUID):
def unregister(self, uid: uuid.UUID) -> None:
with self._lock:
if uid in self._handlers:
del self._handlers[uid]

def exception_handler(self, *args: Any, **kwargs: Any) -> None:
self.trigger(*args, **kwargs)

if self._previous_uncaught_exception_handler is not None:
self._previous_uncaught_exception_handler(*args, **kwargs)


instance = UncaughtExceptionHandler()
1 change: 1 addition & 0 deletions src/neptune/metadata_containers/metadata_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ def _write_initial_attributes(self):
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_tb is not None:
traceback.print_exception(exc_type, exc_val, exc_tb)
uncaught_exception_handler.trigger(exc_type, exc_val, exc_tb)
self.stop()

def __getattr__(self, item):
Expand Down
19 changes: 19 additions & 0 deletions tests/e2e/standard/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,25 @@ def test_tracking_uncommitted_changes(self, repo, environment):
with open("diff.patch") as fp:
assert "some-content" in fp.read()

def test_failing_on_exception_if_in_context_manager(self, environment):
run_id = ""

try:
with neptune.init_run(project=environment.project) as run:
run_id = run["sys/id"].fetch()
raise Exception()
except Exception:
pass

with neptune.init_run(with_id=run_id, project=environment.project) as run:
assert run["sys/failed"].fetch() is True

monitoring_hash = list(run.get_structure()["monitoring"].items())[0][0]
assert run.exists(f"monitoring/{monitoring_hash}/traceback")

traceback_df = run[f"monitoring/{monitoring_hash}/traceback"].fetch_values()
assert "Marking run as failed" in traceback_df["value"].to_list()


class TestInitProject(BaseE2ETest):
def test_resuming_project(self, environment):
Expand Down

0 comments on commit 159f577

Please sign in to comment.