diff --git a/spanner/google/cloud/spanner/session.py b/spanner/google/cloud/spanner/session.py
index 04fcacea38ee8..33a1a8b2838b2 100644
--- a/spanner/google/cloud/spanner/session.py
+++ b/spanner/google/cloud/spanner/session.py
@@ -24,6 +24,7 @@
 
 # pylint: disable=ungrouped-imports
 from google.cloud.exceptions import NotFound
+from google.cloud.exceptions import GrpcRendezvous
 from google.cloud.spanner._helpers import _options_with_prefix
 from google.cloud.spanner.batch import Batch
 from google.cloud.spanner.snapshot import Snapshot
@@ -286,7 +287,7 @@ def run_in_transaction(self, func, *args, **kw):
                 txn.begin()
             try:
                 return_value = func(txn, *args, **kw)
-            except GaxError as exc:
+            except (GaxError, GrpcRendezvous) as exc:
                 _delay_until_retry(exc, deadline)
                 del self._transaction
                 continue
@@ -318,7 +319,12 @@ def _delay_until_retry(exc, deadline):
     :type deadline: float
     :param deadline: maximum timestamp to continue retrying the transaction.
     """
-    if exc_to_code(exc.cause) != StatusCode.ABORTED:
+    if isinstance(exc, GrpcRendezvous):  # pragma: NO COVER see #3663
+        cause = exc
+    else:
+        cause = exc.cause
+
+    if exc_to_code(cause) != StatusCode.ABORTED:
         raise
 
     now = time.time()
@@ -326,7 +332,7 @@
     if now >= deadline:
         raise
 
-    delay = _get_retry_delay(exc)
+    delay = _get_retry_delay(cause)
 
     if delay is not None:
         if now + delay > deadline:
@@ -336,7 +342,7 @@
 
 # pylint: enable=misplaced-bare-raise
 
-def _get_retry_delay(exc):
+def _get_retry_delay(cause):
     """Helper for :func:`_delay_until_retry`.
 
-    :type exc: :class:`google.gax.errors.GaxError`
+    :type cause: :class:`grpc.Call`
@@ -345,7 +351,7 @@ def _get_retry_delay(exc):
 
     :rtype: float
     :returns: seconds to wait before retrying the transaction.
     """
-    metadata = dict(exc.cause.trailing_metadata())
+    metadata = dict(cause.trailing_metadata())
     retry_info_pb = metadata.get('google.rpc.retryinfo-bin')
     if retry_info_pb is not None:
         retry_info = RetryInfo()
diff --git a/spanner/tests/system/test_system.py b/spanner/tests/system/test_system.py
index f5d15d715ed51..fa70573c88deb 100644
--- a/spanner/tests/system/test_system.py
+++ b/spanner/tests/system/test_system.py
@@ -57,6 +57,8 @@
     'google-cloud-python-systest')
 DATABASE_ID = 'test_database'
 EXISTING_INSTANCES = []
+COUNTERS_TABLE = 'counters'
+COUNTERS_COLUMNS = ('name', 'value')
 
 
 class Config(object):
@@ -360,11 +362,6 @@ class TestSessionAPI(unittest.TestCase, _TestData):
         'description',
         'exactly_hwhen',
     )
-    COUNTERS_TABLE = 'counters'
-    COUNTERS_COLUMNS = (
-        'name',
-        'value',
-    )
     SOME_DATE = datetime.date(2011, 1, 17)
     SOME_TIME = datetime.datetime(1989, 1, 17, 17, 59, 12, 345612)
     NANO_TIME = TimestampWithNanoseconds(1995, 8, 31, nanosecond=987654321)
@@ -554,9 +551,7 @@ def _transaction_concurrency_helper(self, unit_of_work, pkey):
 
         with session.batch() as batch:
             batch.insert_or_update(
-                self.COUNTERS_TABLE,
-                self.COUNTERS_COLUMNS,
-                [[pkey, INITIAL_VALUE]])
+                COUNTERS_TABLE, COUNTERS_COLUMNS, [[pkey, INITIAL_VALUE]])
 
         # We don't want to run the threads' transactions in the current
         # session, which would fail.
@@ -582,7 +577,7 @@ def _transaction_concurrency_helper(self, unit_of_work, pkey):
 
         keyset = KeySet(keys=[(pkey,)])
         rows = list(session.read(
-            self.COUNTERS_TABLE, self.COUNTERS_COLUMNS, keyset))
+            COUNTERS_TABLE, COUNTERS_COLUMNS, keyset))
         self.assertEqual(len(rows), 1)
         _, value = rows[0]
         self.assertEqual(value, INITIAL_VALUE + len(threads))
@@ -590,13 +585,11 @@ def _transaction_concurrency_helper(self, unit_of_work, pkey):
     def _read_w_concurrent_update(self, transaction, pkey):
         keyset = KeySet(keys=[(pkey,)])
         rows = list(transaction.read(
-            self.COUNTERS_TABLE, self.COUNTERS_COLUMNS, keyset))
+            COUNTERS_TABLE, COUNTERS_COLUMNS, keyset))
         self.assertEqual(len(rows), 1)
         pkey, value = rows[0]
         transaction.update(
-            self.COUNTERS_TABLE,
-            self.COUNTERS_COLUMNS,
-            [[pkey, value + 1]])
+            COUNTERS_TABLE, COUNTERS_COLUMNS, [[pkey, value + 1]])
 
     def test_transaction_read_w_concurrent_updates(self):
         PKEY = 'read_w_concurrent_updates'
@@ -613,15 +606,48 @@ def _query_w_concurrent_update(self, transaction, pkey):
         self.assertEqual(len(rows), 1)
         pkey, value = rows[0]
         transaction.update(
-            self.COUNTERS_TABLE,
-            self.COUNTERS_COLUMNS,
-            [[pkey, value + 1]])
+            COUNTERS_TABLE, COUNTERS_COLUMNS, [[pkey, value + 1]])
 
     def test_transaction_query_w_concurrent_updates(self):
         PKEY = 'query_w_concurrent_updates'
         self._transaction_concurrency_helper(
             self._query_w_concurrent_update, PKEY)
 
+    def test_transaction_read_w_abort(self):
+
+        retry = RetryInstanceState(_has_all_ddl)
+        retry(self._db.reload)()
+
+        session = self._db.session()
+        session.create()
+
+        trigger = _ReadAbortTrigger()
+
+        with session.batch() as batch:
+            batch.delete(COUNTERS_TABLE, self.ALL)
+            batch.insert(
+                COUNTERS_TABLE,
+                COUNTERS_COLUMNS,
+                [[trigger.KEY1, 0], [trigger.KEY2, 0]])
+
+        provoker = threading.Thread(
+            target=trigger.provoke_abort, args=(self._db,))
+        handler = threading.Thread(
+            target=trigger.handle_abort, args=(self._db,))
+
+        provoker.start()
+        trigger.provoker_started.wait()
+
+        handler.start()
+        trigger.handler_done.wait()
+
+        provoker.join()
+        handler.join()
+
+        rows = list(session.read(COUNTERS_TABLE, COUNTERS_COLUMNS, self.ALL))
+        self._check_row_data(
+            rows, expected=[[trigger.KEY1, 1], [trigger.KEY2, 1]])
+
     @staticmethod
     def _row_data(max_index):
         for index in range(max_index):
@@ -1103,3 +1129,64 @@ def __init__(self, db):
 
     def delete(self):
         self._db.drop()
+
+
+class _ReadAbortTrigger(object):
+    """Helper for tests provoking abort-during-read."""
+
+    KEY1 = 'key1'
+    KEY2 = 'key2'
+
+    def __init__(self):
+        self.provoker_started = threading.Event()
+        self.provoker_done = threading.Event()
+        self.handler_running = threading.Event()
+        self.handler_done = threading.Event()
+
+    def _provoke_abort_unit_of_work(self, transaction):
+        keyset = KeySet(keys=[(self.KEY1,)])
+        rows = list(
+            transaction.read(COUNTERS_TABLE, COUNTERS_COLUMNS, keyset))
+
+        assert len(rows) == 1
+        row = rows[0]
+        value = row[1]
+
+        self.provoker_started.set()
+
+        self.handler_running.wait()
+
+        transaction.update(
+            COUNTERS_TABLE, COUNTERS_COLUMNS, [[self.KEY1, value + 1]])
+
+    def provoke_abort(self, database):
+        database.run_in_transaction(self._provoke_abort_unit_of_work)
+        self.provoker_done.set()
+
+    def _handle_abort_unit_of_work(self, transaction):
+        keyset_1 = KeySet(keys=[(self.KEY1,)])
+        rows_1 = list(
+            transaction.read(COUNTERS_TABLE, COUNTERS_COLUMNS, keyset_1))
+
+        assert len(rows_1) == 1
+        row_1 = rows_1[0]
+        value_1 = row_1[1]
+
+        self.handler_running.set()
+
+        self.provoker_done.wait()
+
+        keyset_2 = KeySet(keys=[(self.KEY2,)])
+        rows_2 = list(
+            transaction.read(COUNTERS_TABLE, COUNTERS_COLUMNS, keyset_2))
+
+        assert len(rows_2) == 1
+        row_2 = rows_2[0]
+        value_2 = row_2[1]
+
+        transaction.update(
+            COUNTERS_TABLE, COUNTERS_COLUMNS, [[self.KEY2, value_1 + value_2]])
+
+    def handle_abort(self, database):
+        database.run_in_transaction(self._handle_abort_unit_of_work)
+        self.handler_done.set()
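
Note on the retry path: when Cloud Spanner aborts a transaction, the ABORTED error carries a serialized google.rpc.RetryInfo message in its trailing metadata under the key 'google.rpc.retryinfo-bin', and the patched _get_retry_delay (its tail is cut off by the hunk context above) turns that into a sleep interval. Below is a minimal, self-contained sketch of that decode step. It assumes only that `cause` exposes trailing_metadata(); FakeRpcCall and get_retry_delay are illustrative stand-ins, not code from this patch.

from google.protobuf.duration_pb2 import Duration
from google.rpc.error_details_pb2 import RetryInfo


class FakeRpcCall(object):
    """Stand-in for the one part of the grpc.Call interface used here."""

    def __init__(self, metadata):
        self._metadata = metadata

    def trailing_metadata(self):
        return self._metadata


def get_retry_delay(cause):
    """Return seconds to wait before retrying, or None if unspecified."""
    metadata = dict(cause.trailing_metadata())
    retry_info_pb = metadata.get('google.rpc.retryinfo-bin')
    if retry_info_pb is None:
        return None
    retry_info = RetryInfo()
    retry_info.ParseFromString(retry_info_pb)
    # retry_delay is a google.protobuf.Duration: whole seconds plus nanos.
    return retry_info.retry_delay.seconds + retry_info.retry_delay.nanos / 1e9


# Usage: a server-supplied 0.5s backoff round-trips through the metadata.
_info = RetryInfo(retry_delay=Duration(nanos=500000000))
_call = FakeRpcCall([('google.rpc.retryinfo-bin', _info.SerializeToString())])
assert get_retry_delay(_call) == 0.5

Normalizing to `cause` first, as the patched _delay_until_retry does, lets the same decode work whether the error arrived wrapped in a GaxError or as a bare GrpcRendezvous raised directly by gRPC.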