Skip to content

Commit

Permalink
Implement Context.call_on_commit.
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Rossi committed Aug 15, 2019
1 parent 864aefe commit 4345ca6
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 4 deletions.
8 changes: 7 additions & 1 deletion src/google/cloud/ndb/_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ def _transaction_async(context, callback, read_only=False):
read_only, retries=0
)

with context.new(transaction=transaction_id).use() as tx_context:
on_commit_callbacks = []
tx_context = context.new(
transaction=transaction_id, on_commit_callbacks=on_commit_callbacks
)
with tx_context.use():
try:
# Run the callback
result = callback()
Expand All @@ -115,6 +119,8 @@ def _transaction_async(context, callback, read_only=False):
raise

tx_context._clear_global_cache()
for callback in on_commit_callbacks:
callback()

return result

Expand Down
8 changes: 7 additions & 1 deletion src/google/cloud/ndb/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def policy(key):
"transaction",
"cache",
"global_cache",
"on_commit_callbacks",
],
)

Expand Down Expand Up @@ -178,6 +179,7 @@ def __new__(
global_cache_policy=None,
global_cache_timeout_policy=None,
datastore_policy=None,
on_commit_callbacks=None,
):
if eventloop is None:
eventloop = _eventloop.EventLoop()
Expand Down Expand Up @@ -207,6 +209,7 @@ def __new__(
transaction=transaction,
cache=new_cache,
global_cache=global_cache,
on_commit_callbacks=on_commit_callbacks,
)

context.set_cache_policy(cache_policy)
Expand Down Expand Up @@ -468,7 +471,10 @@ def call_on_commit(self, callback):
Args:
callback (Callable): The callback function.
"""
raise NotImplementedError
if self.in_transaction():
self.on_commit_callbacks.append(callback)
else:
callback()

def in_transaction(self):
"""Get whether a transaction is currently active.
Expand Down
4 changes: 4 additions & 0 deletions tests/system/test_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,11 +406,14 @@ class SomeKind(ndb.Model):

@pytest.mark.usefixtures("client_context")
def test_insert_entity_in_transaction(dispose_of):
commit_callback = mock.Mock()

class SomeKind(ndb.Model):
foo = ndb.IntegerProperty()
bar = ndb.StringProperty()

def save_entity():
ndb.get_context().call_on_commit(commit_callback)
entity = SomeKind(foo=42, bar="none")
key = entity.put()
dispose_of(key._key)
Expand All @@ -420,6 +423,7 @@ def save_entity():
retrieved = key.get()
assert retrieved.foo == 42
assert retrieved.bar == "none"
commit_callback.assert_called_once_with()


@pytest.mark.usefixtures("client_context")
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/test__transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pytest

from google.api_core import exceptions as core_exceptions
from google.cloud.ndb import context as context_module
from google.cloud.ndb import exceptions
from google.cloud.ndb import tasklets
from google.cloud.ndb import _transaction
Expand Down Expand Up @@ -75,7 +76,10 @@ class Test_transaction_async:
@pytest.mark.usefixtures("in_context")
@mock.patch("google.cloud.ndb._transaction._datastore_api")
def test_success(_datastore_api):
on_commit_callback = mock.Mock()

def callback():
context_module.get_context().call_on_commit(on_commit_callback)
return "I tried, momma."

begin_future = tasklets.Future("begin transaction")
Expand All @@ -95,6 +99,7 @@ def callback():
commit_future.set_result(None)

assert future.result() == "I tried, momma."
on_commit_callback.assert_called_once_with()

@staticmethod
@pytest.mark.usefixtures("in_context")
Expand Down
14 changes: 12 additions & 2 deletions tests/unit/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,18 @@ class SomeKind(model.Model):

def test_call_on_commit(self):
context = self._make_one()
with pytest.raises(NotImplementedError):
context.call_on_commit(None)
callback = mock.Mock()
context.call_on_commit(callback)
callback.assert_called_once_with()

def test_call_on_commit_with_transaction(self):
callbacks = []
callback = "himom!"
context = self._make_one(
transaction=b"tx123", on_commit_callbacks=callbacks
)
context.call_on_commit(callback)
assert context.on_commit_callbacks == ["himom!"]

def test_in_transaction(self):
context = self._make_one()
Expand Down

0 comments on commit 4345ca6

Please sign in to comment.