diff --git a/spanner/google/cloud/spanner/snapshot.py b/spanner/google/cloud/spanner/snapshot.py index 43a4c13ee6c0..3962e866e2bc 100644 --- a/spanner/google/cloud/spanner/snapshot.py +++ b/spanner/google/cloud/spanner/snapshot.py @@ -73,10 +73,14 @@ def read(self, table, columns, keyset, index='', limit=0, :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. - :raises: ValueError for reuse of single-use snapshots. + :raises: ValueError for reuse of single-use snapshots, or if a + transaction ID is pending for multiple-use snapshots. """ - if not self._multi_use and self._read_request_count > 0: - raise ValueError("Cannot re-use single-use snapshot.") + if self._read_request_count > 0: + if not self._multi_use: + raise ValueError("Cannot re-use single-use snapshot.") + if self._transaction_id is None: + raise ValueError("Transaction ID pending.") database = self._session._database api = database.spanner_api @@ -121,10 +125,14 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None, :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. - :raises: ValueError for reuse of single-use snapshots. + :raises: ValueError for reuse of single-use snapshots, or if a + transaction ID is pending for multiple-use snapshots. """ - if not self._multi_use and self._read_request_count > 0: - raise ValueError("Cannot re-use single-use snapshot.") + if self._read_request_count > 0: + if not self._multi_use: + raise ValueError("Cannot re-use single-use snapshot.") + if self._transaction_id is None: + raise ValueError("Transaction ID pending.") if params is not None: if param_types is None: diff --git a/spanner/google/cloud/spanner/transaction.py b/spanner/google/cloud/spanner/transaction.py index af2140896830..80896055b89b 100644 --- a/spanner/google/cloud/spanner/transaction.py +++ b/spanner/google/cloud/spanner/transaction.py @@ -32,6 +32,7 @@ def __init__(self, session): super(Transaction, self).__init__(session) self._id = None self._rolled_back = False + self._multi_use = True def _check_state(self): """Helper for :meth:`commit` et al. diff --git a/spanner/tests/unit/test_snapshot.py b/spanner/tests/unit/test_snapshot.py index 5a41e18e177e..4717a14c2f24 100644 --- a/spanner/tests/unit/test_snapshot.py +++ b/spanner/tests/unit/test_snapshot.py @@ -112,7 +112,7 @@ def test_read_grpc_error(self): self.assertEqual(options.kwargs['metadata'], [('google-cloud-resource-prefix', database.name)]) - def _read_helper(self, multi_use, first=True): + def _read_helper(self, multi_use, first=True, count=0): from google.protobuf.struct_pb2 import Struct from google.cloud.proto.spanner.v1.result_set_pb2 import ( PartialResultSet, ResultSetMetadata, ResultSetStats) @@ -156,18 +156,15 @@ def _read_helper(self, multi_use, first=True): session = _Session(database) derived = self._makeDerived(session) derived._multi_use = multi_use + derived._read_request_count = count if not first: derived._transaction_id = TXN_ID - derived._read_request_count = 1 result_set = derived.read( TABLE_NAME, COLUMNS, KEYSET, index=INDEX, limit=LIMIT, resume_token=TOKEN) - if first: - self.assertEqual(derived._read_request_count, 1) - else: - self.assertEqual(derived._read_request_count, 2) + self.assertEqual(derived._read_request_count, count + 1) if multi_use: self.assertIs(result_set._source, derived) @@ -205,14 +202,21 @@ def test_read_wo_multi_use(self): def test_read_wo_multi_use_w_read_request_count_gt_0(self): with self.assertRaises(ValueError): - self._read_helper(multi_use=False, first=False) + self._read_helper(multi_use=False, count=1) def test_read_w_multi_use_wo_first(self): self._read_helper(multi_use=True, first=False) + def test_read_w_multi_use_wo_first_w_count_gt_0(self): + self._read_helper(multi_use=True, first=False, count=1) + def test_read_w_multi_use_w_first(self): self._read_helper(multi_use=True, first=True) + def test_read_w_multi_use_w_first_w_count_gt_0(self): + with self.assertRaises(ValueError): + self._read_helper(multi_use=True, first=True, count=1) + def test_execute_sql_grpc_error(self): from google.cloud.proto.spanner.v1.transaction_pb2 import ( TransactionSelector) @@ -249,7 +253,7 @@ def test_execute_sql_w_params_wo_param_types(self): with self.assertRaises(ValueError): derived.execute_sql(SQL_QUERY_WITH_PARAM, PARAMS) - def _execute_sql_helper(self, multi_use, first=True): + def _execute_sql_helper(self, multi_use, first=True, count=0): from google.protobuf.struct_pb2 import Struct from google.cloud.proto.spanner.v1.result_set_pb2 import ( PartialResultSet, ResultSetMetadata, ResultSetStats) @@ -291,18 +295,15 @@ def _execute_sql_helper(self, multi_use, first=True): session = _Session(database) derived = self._makeDerived(session) derived._multi_use = multi_use + derived._read_request_count = count if not first: derived._transaction_id = TXN_ID - derived._read_request_count = 1 result_set = derived.execute_sql( SQL_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES, query_mode=MODE, resume_token=TOKEN) - if first: - self.assertEqual(derived._read_request_count, 1) - else: - self.assertEqual(derived._read_request_count, 2) + self.assertEqual(derived._read_request_count, count + 1) if multi_use: self.assertIs(result_set._source, derived) @@ -341,13 +342,20 @@ def test_execute_sql_wo_multi_use(self): def test_execute_sql_wo_multi_use_w_read_request_count_gt_0(self): with self.assertRaises(ValueError): - self._execute_sql_helper(multi_use=False, first=False) + self._execute_sql_helper(multi_use=False, count=1) + + def test_execute_sql_w_multi_use_wo_first(self): + self._execute_sql_helper(multi_use=True, first=False) + + def test_execute_sql_w_multi_use_wo_first_w_count_gt_0(self): + self._execute_sql_helper(multi_use=True, first=False, count=1) def test_execute_sql_w_multi_use_w_first(self): self._execute_sql_helper(multi_use=True, first=True) - def test_execute_sql_w_multi_use_wo_first(self): - self._execute_sql_helper(multi_use=True, first=False) + def test_execute_sql_w_multi_use_w_first_w_count_gt_0(self): + with self.assertRaises(ValueError): + self._execute_sql_helper(multi_use=True, first=True, count=1) class _MockCancellableIterator(object): diff --git a/spanner/tests/unit/test_transaction.py b/spanner/tests/unit/test_transaction.py index 997f4d5153c8..7dd57242167f 100644 --- a/spanner/tests/unit/test_transaction.py +++ b/spanner/tests/unit/test_transaction.py @@ -51,7 +51,8 @@ def test_ctor_defaults(self): self.assertIs(transaction._session, session) self.assertIsNone(transaction._id) self.assertIsNone(transaction.committed) - self.assertEqual(transaction._rolled_back, False) + self.assertFalse(transaction._rolled_back) + self.assertTrue(transaction._multi_use) def test__check_state_not_begun(self): session = _Session()