Skip to content

Commit

Permalink
Handling condition where Entity.cause is not a dict. (aws#267)
Browse files Browse the repository at this point in the history
* Handling condition where cause is not a dict. Exceptions should be appended not replaced.

* Adding more test cases

* Minor fixes to some tests

* Some type checking in python2 may require to import unicode literals

* Checking python version for type comparison
  • Loading branch information
srprash authored and Tyler Hargraves committed Mar 22, 2022
1 parent 338b0e4 commit e8c60f2
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 2 deletions.
13 changes: 11 additions & 2 deletions aws_xray_sdk/core/models/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def add_exception(self, exception, stack, remote=False):
"""
Add an exception to trace entities.
:param Exception exception: the catched exception.
:param Exception exception: the caught exception.
:param list stack: the output from python built-in
`traceback.extract_stack()`.
:param bool remote: If False it means it's a client error
Expand All @@ -224,7 +224,16 @@ def add_exception(self, exception, stack, remote=False):
setattr(self, 'cause', getattr(exception, '_cause_id'))
return

exceptions = []
if not isinstance(self.cause, dict):
log.warning("The current cause object is not a dict but an id: {}. Resetting the cause and recording the "
"current exception".format(self.cause))
self.cause = {}

if 'exceptions' in self.cause:
exceptions = self.cause['exceptions']
else:
exceptions = []

exceptions.append(Throwable(exception, stack, remote))

self.cause['exceptions'] = exceptions
Expand Down
69 changes: 69 additions & 0 deletions tests/test_trace_entities.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# -*- coding: iso-8859-15 -*-

import pytest
import sys

from aws_xray_sdk.core.models.segment import Segment
from aws_xray_sdk.core.models.subsegment import Subsegment
Expand Down Expand Up @@ -194,3 +196,70 @@ def test_missing_parent_segment():

with pytest.raises(SegmentNotFoundException):
Subsegment('name', 'local', None)


def test_add_exception():
segment = Segment('seg')
exception = Exception("testException")
stack = [['path', 'line', 'label']]
segment.add_exception(exception=exception, stack=stack)
segment.close()

cause = segment.cause
assert 'exceptions' in cause
exceptions = cause['exceptions']
assert len(exceptions) == 1
assert 'working_directory' in cause
exception = exceptions[0]
assert 'testException' == exception.message
expected_stack = [{'path': 'path', 'line': 'line', 'label': 'label'}]
assert expected_stack == exception.stack


def test_add_exception_referencing():
segment = Segment('seg')
subseg = Subsegment('subseg', 'remote', segment)
exception = Exception("testException")
stack = [['path', 'line', 'label']]
subseg.add_exception(exception=exception, stack=stack)
segment.add_exception(exception=exception, stack=stack)
subseg.close()
segment.close()

seg_cause = segment.cause
subseg_cause = subseg.cause

assert isinstance(subseg_cause, dict)
if sys.version_info.major == 2:
assert isinstance(seg_cause, basestring)
else:
assert isinstance(seg_cause, str)
assert seg_cause == subseg_cause['exceptions'][0].id


def test_add_exception_cause_resetting():
segment = Segment('seg')
subseg = Subsegment('subseg', 'remote', segment)
exception = Exception("testException")
stack = [['path', 'line', 'label']]
subseg.add_exception(exception=exception, stack=stack)
segment.add_exception(exception=exception, stack=stack)

segment.add_exception(exception=Exception("newException"), stack=stack)
subseg.close()
segment.close()

seg_cause = segment.cause
assert isinstance(seg_cause, dict)
assert 'newException' == seg_cause['exceptions'][0].message


def test_add_exception_appending_exceptions():
segment = Segment('seg')
stack = [['path', 'line', 'label']]
segment.add_exception(exception=Exception("testException"), stack=stack)
segment.add_exception(exception=Exception("newException"), stack=stack)
segment.close()

assert isinstance(segment.cause, dict)
assert len(segment.cause['exceptions']) == 2

0 comments on commit e8c60f2

Please sign in to comment.