Skip to content

Commit

Permalink
Merge pull request #2303 from dhermes/fix-2297
Browse files Browse the repository at this point in the history
Making datastore batch/transaction more robust to failure.
  • Loading branch information
dhermes authored Sep 16, 2016
2 parents 227da50 + 6020ee7 commit 089138e
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 9 deletions.
27 changes: 24 additions & 3 deletions google/cloud/datastore/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,13 @@ def put(self, entity):
:type entity: :class:`google.cloud.datastore.entity.Entity`
:param entity: the entity to be saved.
:raises: ValueError if entity has no key assigned, or if the key's
:raises: :class:`~exceptions.ValueError` if the batch is not in
progress, if entity has no key assigned, or if the key's
``project`` does not match ours.
"""
if self._status != self._IN_PROGRESS:
raise ValueError('Batch must be in progress to put()')

if entity.key is None:
raise ValueError("Entity must have a key")

Expand All @@ -206,9 +210,13 @@ def delete(self, key):
:type key: :class:`google.cloud.datastore.key.Key`
:param key: the key to be deleted.
:raises: ValueError if key is not complete, or if the key's
:raises: :class:`~exceptions.ValueError` if the batch is not in
progress, if key is not complete, or if the key's
``project`` does not match ours.
"""
if self._status != self._IN_PROGRESS:
raise ValueError('Batch must be in progress to delete()')

if key.is_partial:
raise ValueError("Key must be complete")

Expand Down Expand Up @@ -255,7 +263,13 @@ def commit(self):
This is called automatically upon exiting a with statement,
however it can be called explicitly if you don't want to use a
context manager.
:raises: :class:`~exceptions.ValueError` if the batch is not
in progress.
"""
if self._status != self._IN_PROGRESS:
raise ValueError('Batch must be in progress to commit()')

try:
self._commit()
finally:
Expand All @@ -267,12 +281,19 @@ def rollback(self):
Marks the batch as aborted (can't be used again).
Overridden by :class:`google.cloud.datastore.transaction.Transaction`.
:raises: :class:`~exceptions.ValueError` if the batch is not
in progress.
"""
if self._status != self._IN_PROGRESS:
raise ValueError('Batch must be in progress to rollback()')

self._status = self._ABORTED

def __enter__(self):
self._client._push_batch(self)
self.begin()
# NOTE: We make sure begin() succeeds before pushing onto the stack.
self._client._push_batch(self)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/datastore/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ def put_multi(self, entities):

if not in_batch:
current = self.batch()
current.begin()

for entity in entities:
current.put(entity)
Expand Down Expand Up @@ -384,6 +385,7 @@ def delete_multi(self, keys):

if not in_batch:
current = self.batch()
current.begin()

for key in keys:
current.delete(key)
Expand Down
11 changes: 9 additions & 2 deletions google/cloud/datastore/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ class Transaction(Batch):
:param client: the client used to connect to datastore.
"""

_status = None

def __init__(self, client):
super(Transaction, self).__init__(client)
self._id = None
Expand Down Expand Up @@ -125,10 +127,15 @@ def begin(self):
statement, however it can be called explicitly if you don't want
to use a context manager.
:raises: :class:`ValueError` if the transaction has already begun.
:raises: :class:`~exceptions.ValueError` if the transaction has
already begun.
"""
super(Transaction, self).begin()
self._id = self.connection.begin_transaction(self.project)
try:
self._id = self.connection.begin_transaction(self.project)
except:
self._status = self._ABORTED
raise

def rollback(self):
"""Rolls back the current transaction.
Expand Down
73 changes: 72 additions & 1 deletion unit_tests/datastore/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,20 @@ def test_put_entity_wo_key(self):
client = _Client(_PROJECT, connection)
batch = self._makeOne(client)

batch.begin()
self.assertRaises(ValueError, batch.put, _Entity())

def test_put_entity_wrong_status(self):
_PROJECT = 'PROJECT'
connection = _Connection()
client = _Client(_PROJECT, connection)
batch = self._makeOne(client)
entity = _Entity()
entity.key = _Key('OTHER')

self.assertEqual(batch._status, batch._INITIAL)
self.assertRaises(ValueError, batch.put, entity)

def test_put_entity_w_key_wrong_project(self):
_PROJECT = 'PROJECT'
connection = _Connection()
Expand All @@ -78,6 +90,7 @@ def test_put_entity_w_key_wrong_project(self):
entity = _Entity()
entity.key = _Key('OTHER')

batch.begin()
self.assertRaises(ValueError, batch.put, entity)

def test_put_entity_w_partial_key(self):
Expand All @@ -90,6 +103,7 @@ def test_put_entity_w_partial_key(self):
key = entity.key = _Key(_PROJECT)
key._id = None

batch.begin()
batch.put(entity)

mutated_entity = _mutated_pb(self, batch.mutations, 'insert')
Expand All @@ -113,6 +127,7 @@ def test_put_entity_w_completed_key(self):
entity.exclude_from_indexes = ('baz', 'spam')
key = entity.key = _Key(_PROJECT)

batch.begin()
batch.put(entity)

mutated_entity = _mutated_pb(self, batch.mutations, 'upsert')
Expand All @@ -129,6 +144,17 @@ def test_put_entity_w_completed_key(self):
self.assertTrue(spam_values[2].exclude_from_indexes)
self.assertFalse('frotz' in prop_dict)

def test_delete_wrong_status(self):
_PROJECT = 'PROJECT'
connection = _Connection()
client = _Client(_PROJECT, connection)
batch = self._makeOne(client)
key = _Key(_PROJECT)
key._id = None

self.assertEqual(batch._status, batch._INITIAL)
self.assertRaises(ValueError, batch.delete, key)

def test_delete_w_partial_key(self):
_PROJECT = 'PROJECT'
connection = _Connection()
Expand All @@ -137,6 +163,7 @@ def test_delete_w_partial_key(self):
key = _Key(_PROJECT)
key._id = None

batch.begin()
self.assertRaises(ValueError, batch.delete, key)

def test_delete_w_key_wrong_project(self):
Expand All @@ -146,6 +173,7 @@ def test_delete_w_key_wrong_project(self):
batch = self._makeOne(client)
key = _Key('OTHER')

batch.begin()
self.assertRaises(ValueError, batch.delete, key)

def test_delete_w_completed_key(self):
Expand All @@ -155,6 +183,7 @@ def test_delete_w_completed_key(self):
batch = self._makeOne(client)
key = _Key(_PROJECT)

batch.begin()
batch.delete(key)

mutated_key = _mutated_pb(self, batch.mutations, 'delete')
Expand All @@ -180,23 +209,43 @@ def test_rollback(self):
_PROJECT = 'PROJECT'
client = _Client(_PROJECT, None)
batch = self._makeOne(client)
self.assertEqual(batch._status, batch._INITIAL)
batch.begin()
self.assertEqual(batch._status, batch._IN_PROGRESS)
batch.rollback()
self.assertEqual(batch._status, batch._ABORTED)

def test_rollback_wrong_status(self):
_PROJECT = 'PROJECT'
client = _Client(_PROJECT, None)
batch = self._makeOne(client)

self.assertEqual(batch._status, batch._INITIAL)
self.assertRaises(ValueError, batch.rollback)

def test_commit(self):
_PROJECT = 'PROJECT'
connection = _Connection()
client = _Client(_PROJECT, connection)
batch = self._makeOne(client)

self.assertEqual(batch._status, batch._INITIAL)
batch.begin()
self.assertEqual(batch._status, batch._IN_PROGRESS)
batch.commit()
self.assertEqual(batch._status, batch._FINISHED)

self.assertEqual(connection._committed,
[(_PROJECT, batch._commit_request, None)])

def test_commit_wrong_status(self):
_PROJECT = 'PROJECT'
connection = _Connection()
client = _Client(_PROJECT, connection)
batch = self._makeOne(client)

self.assertEqual(batch._status, batch._INITIAL)
self.assertRaises(ValueError, batch.commit)

def test_commit_w_partial_key_entities(self):
_PROJECT = 'PROJECT'
_NEW_ID = 1234
Expand All @@ -209,6 +258,8 @@ def test_commit_w_partial_key_entities(self):
batch._partial_key_entities.append(entity)

self.assertEqual(batch._status, batch._INITIAL)
batch.begin()
self.assertEqual(batch._status, batch._IN_PROGRESS)
batch.commit()
self.assertEqual(batch._status, batch._FINISHED)

Expand Down Expand Up @@ -295,6 +346,26 @@ def test_as_context_mgr_w_error(self):
self.assertEqual(mutated_entity.key, key._key)
self.assertEqual(connection._committed, [])

def test_as_context_mgr_enter_fails(self):
klass = self._getTargetClass()

class FailedBegin(klass):

def begin(self):
raise RuntimeError

client = _Client(None, None)
self.assertEqual(client._batches, [])

batch = FailedBegin(client)
with self.assertRaises(RuntimeError):
# The context manager will never be entered because
# of the failure.
with batch: # pragma: NO COVER
pass
# Make sure no batch was added.
self.assertEqual(client._batches, [])


class _PathElementPB(object):

Expand Down
3 changes: 3 additions & 0 deletions unit_tests/datastore/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,7 @@ def __init__(self, client):
from google.cloud.datastore.batch import Batch
self._client = client
self._batch = Batch(client)
self._batch.begin()

def __enter__(self):
self._client._push_batch(self._batch)
Expand All @@ -972,10 +973,12 @@ def __exit__(self, *args):
class _NoCommitTransaction(object):

def __init__(self, client, transaction_id='TRANSACTION'):
from google.cloud.datastore.batch import Batch
from google.cloud.datastore.transaction import Transaction
self._client = client
xact = self._transaction = Transaction(client)
xact._id = transaction_id
Batch.begin(xact)

def __enter__(self):
self._client._push_batch(self._transaction)
Expand Down
25 changes: 22 additions & 3 deletions unit_tests/datastore/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,19 @@ def test_begin_tombstoned(self):

self.assertRaises(ValueError, xact.begin)

def test_begin_w_begin_transaction_failure(self):
_PROJECT = 'PROJECT'
connection = _Connection(234)
client = _Client(_PROJECT, connection)
xact = self._makeOne(client)

connection._side_effect = RuntimeError
with self.assertRaises(RuntimeError):
xact.begin()

self.assertIsNone(xact.id)
self.assertEqual(connection._begun, _PROJECT)

def test_rollback(self):
_PROJECT = 'PROJECT'
connection = _Connection(234)
Expand Down Expand Up @@ -118,10 +131,10 @@ def test_commit_w_partial_keys(self):
connection._completed_keys = [_make_key(_KIND, _ID, _PROJECT)]
client = _Client(_PROJECT, connection)
xact = self._makeOne(client)
xact.begin()
entity = _Entity()
xact.put(entity)
xact._commit_request = commit_request = object()
xact.begin()
xact.commit()
self.assertEqual(connection._committed,
(_PROJECT, commit_request, 234))
Expand Down Expand Up @@ -176,7 +189,10 @@ def _make_key(kind, id_, project):

class _Connection(object):
_marker = object()
_begun = _rolled_back = _committed = None
_begun = None
_rolled_back = None
_committed = None
_side_effect = None

def __init__(self, xact_id=123):
self._xact_id = xact_id
Expand All @@ -185,7 +201,10 @@ def __init__(self, xact_id=123):

def begin_transaction(self, project):
self._begun = project
return self._xact_id
if self._side_effect is None:
return self._xact_id
else:
raise self._side_effect

def rollback(self, project, transaction_id):
self._rolled_back = project, transaction_id
Expand Down

0 comments on commit 089138e

Please sign in to comment.