Skip to content

Commit

Permalink
Add guard against pending transaction for multi-use snapshots.
Browse files Browse the repository at this point in the history
  • Loading branch information
tseaver committed Jul 24, 2017
1 parent ec38b8b commit 8244af9
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 23 deletions.
20 changes: 14 additions & 6 deletions spanner/google/cloud/spanner/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,14 @@ def read(self, table, columns, keyset, index='', limit=0,
:rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.
:raises: ValueError for reuse of single-use snapshots.
:raises: ValueError for reuse of single-use snapshots, or if a
transaction ID is pending for multiple-use snapshots.
"""
if not self._multi_use and self._read_request_count > 0:
raise ValueError("Cannot re-use single-use snapshot.")
if self._read_request_count > 0:
if not self._multi_use:
raise ValueError("Cannot re-use single-use snapshot.")
if self._transaction_id is None:
raise ValueError("Transaction ID pending.")

database = self._session._database
api = database.spanner_api
Expand Down Expand Up @@ -121,10 +125,14 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None,
:rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.
:raises: ValueError for reuse of single-use snapshots.
:raises: ValueError for reuse of single-use snapshots, or if a
transaction ID is pending for multiple-use snapshots.
"""
if not self._multi_use and self._read_request_count > 0:
raise ValueError("Cannot re-use single-use snapshot.")
if self._read_request_count > 0:
if not self._multi_use:
raise ValueError("Cannot re-use single-use snapshot.")
if self._transaction_id is None:
raise ValueError("Transaction ID pending.")

if params is not None:
if param_types is None:
Expand Down
1 change: 1 addition & 0 deletions spanner/google/cloud/spanner/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(self, session):
super(Transaction, self).__init__(session)
self._id = None
self._rolled_back = False
self._multi_use = True

def _check_state(self):
"""Helper for :meth:`commit` et al.
Expand Down
40 changes: 24 additions & 16 deletions spanner/tests/unit/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_read_grpc_error(self):
self.assertEqual(options.kwargs['metadata'],
[('google-cloud-resource-prefix', database.name)])

def _read_helper(self, multi_use, first=True):
def _read_helper(self, multi_use, first=True, count=0):
from google.protobuf.struct_pb2 import Struct
from google.cloud.proto.spanner.v1.result_set_pb2 import (
PartialResultSet, ResultSetMetadata, ResultSetStats)
Expand Down Expand Up @@ -156,18 +156,15 @@ def _read_helper(self, multi_use, first=True):
session = _Session(database)
derived = self._makeDerived(session)
derived._multi_use = multi_use
derived._read_request_count = count
if not first:
derived._transaction_id = TXN_ID
derived._read_request_count = 1

result_set = derived.read(
TABLE_NAME, COLUMNS, KEYSET,
index=INDEX, limit=LIMIT, resume_token=TOKEN)

if first:
self.assertEqual(derived._read_request_count, 1)
else:
self.assertEqual(derived._read_request_count, 2)
self.assertEqual(derived._read_request_count, count + 1)

if multi_use:
self.assertIs(result_set._source, derived)
Expand Down Expand Up @@ -205,14 +202,21 @@ def test_read_wo_multi_use(self):

def test_read_wo_multi_use_w_read_request_count_gt_0(self):
with self.assertRaises(ValueError):
self._read_helper(multi_use=False, first=False)
self._read_helper(multi_use=False, count=1)

def test_read_w_multi_use_wo_first(self):
self._read_helper(multi_use=True, first=False)

def test_read_w_multi_use_wo_first_w_count_gt_0(self):
self._read_helper(multi_use=True, first=False, count=1)

def test_read_w_multi_use_w_first(self):
self._read_helper(multi_use=True, first=True)

def test_read_w_multi_use_w_first_w_count_gt_0(self):
with self.assertRaises(ValueError):
self._read_helper(multi_use=True, first=True, count=1)

def test_execute_sql_grpc_error(self):
from google.cloud.proto.spanner.v1.transaction_pb2 import (
TransactionSelector)
Expand Down Expand Up @@ -249,7 +253,7 @@ def test_execute_sql_w_params_wo_param_types(self):
with self.assertRaises(ValueError):
derived.execute_sql(SQL_QUERY_WITH_PARAM, PARAMS)

def _execute_sql_helper(self, multi_use, first=True):
def _execute_sql_helper(self, multi_use, first=True, count=0):
from google.protobuf.struct_pb2 import Struct
from google.cloud.proto.spanner.v1.result_set_pb2 import (
PartialResultSet, ResultSetMetadata, ResultSetStats)
Expand Down Expand Up @@ -291,18 +295,15 @@ def _execute_sql_helper(self, multi_use, first=True):
session = _Session(database)
derived = self._makeDerived(session)
derived._multi_use = multi_use
derived._read_request_count = count
if not first:
derived._transaction_id = TXN_ID
derived._read_request_count = 1

result_set = derived.execute_sql(
SQL_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES,
query_mode=MODE, resume_token=TOKEN)

if first:
self.assertEqual(derived._read_request_count, 1)
else:
self.assertEqual(derived._read_request_count, 2)
self.assertEqual(derived._read_request_count, count + 1)

if multi_use:
self.assertIs(result_set._source, derived)
Expand Down Expand Up @@ -341,13 +342,20 @@ def test_execute_sql_wo_multi_use(self):

def test_execute_sql_wo_multi_use_w_read_request_count_gt_0(self):
with self.assertRaises(ValueError):
self._execute_sql_helper(multi_use=False, first=False)
self._execute_sql_helper(multi_use=False, count=1)

def test_execute_sql_w_multi_use_wo_first(self):
self._execute_sql_helper(multi_use=True, first=False)

def test_execute_sql_w_multi_use_wo_first_w_count_gt_0(self):
self._execute_sql_helper(multi_use=True, first=False, count=1)

def test_execute_sql_w_multi_use_w_first(self):
self._execute_sql_helper(multi_use=True, first=True)

def test_execute_sql_w_multi_use_wo_first(self):
self._execute_sql_helper(multi_use=True, first=False)
def test_execute_sql_w_multi_use_w_first_w_count_gt_0(self):
with self.assertRaises(ValueError):
self._execute_sql_helper(multi_use=True, first=True, count=1)


class _MockCancellableIterator(object):
Expand Down
3 changes: 2 additions & 1 deletion spanner/tests/unit/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def test_ctor_defaults(self):
self.assertIs(transaction._session, session)
self.assertIsNone(transaction._id)
self.assertIsNone(transaction.committed)
self.assertEqual(transaction._rolled_back, False)
self.assertFalse(transaction._rolled_back)
self.assertTrue(transaction._multi_use)

def test__check_state_not_begun(self):
session = _Session()
Expand Down

0 comments on commit 8244af9

Please sign in to comment.