fix: refactor transactions to use their own event loops #443

Merged (4 commits) on May 22, 2020
4 changes: 1 addition & 3 deletions google/cloud/ndb/_batch.py
@@ -14,8 +14,6 @@

"""Support for batching operations."""

from google.cloud.ndb import _eventloop


def get_batch(batch_cls, options=None):
"""Gets a data structure for storing batched calls to Datastore Lookup.
@@ -68,5 +66,5 @@ def idle():
return idle

batches[options_key] = batch = batch_cls(options)
_eventloop.add_idle(idler(batch))
context.eventloop.add_idle(idler(batch))
return batch
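
For readers following the change above: get_batch now registers its idler on the event loop attached to the current context rather than on the module-level _eventloop helper. Below is a minimal, self-contained sketch of that idle-batching pattern; the EventLoop and Batch classes are simplified stand-ins, not the real ndb implementations.

class EventLoop:
    """Toy event loop with just the add_idle hook used by get_batch."""

    def __init__(self):
        self._idlers = []

    def add_idle(self, callback, *args):
        # Remember a callback to run when the loop has nothing else to do.
        self._idlers.append((callback, args))

    def run_idle(self):
        while self._idlers:
            callback, args = self._idlers.pop(0)
            callback(*args)


class Batch:
    """Toy batch that just collects calls and flushes them when idle."""

    def __init__(self):
        self.calls = []

    def idle_callback(self):
        print("flushing", len(self.calls), "batched calls")
        self.calls.clear()


loop = EventLoop()  # in the real code this is context.eventloop
batch = Batch()
batch.calls.extend(["lookup-1", "lookup-2"])
loop.add_idle(batch.idle_callback)
loop.run_idle()  # prints: flushing 2 batched calls
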
24 changes: 17 additions & 7 deletions google/cloud/ndb/_datastore_api.py
@@ -519,6 +519,19 @@ def commit_callback(rpc):
rpc.add_done_callback(commit_callback)


def prepare_to_commit(transaction):
"""Signal that we're ready to commit a transaction.

Currently just used to signal to the commit batch that we're not going to
need to call `AllocateIds`, because we're ready to commit now.

Args:
transaction (bytes): The transaction id about to be committed.
"""
batch = _get_commit_batch(transaction, _options.Options())
batch.preparing_to_commit = True


def commit(transaction, retries=None, timeout=None):
"""Commit a transaction.

@@ -605,6 +618,7 @@ def __init__(self, transaction, options):
self.allocating_ids = []
self.incomplete_mutations = []
self.incomplete_futures = []
self.preparing_to_commit = False

def put(self, entity_pb):
"""Add an entity to batch to be stored.
@@ -657,8 +671,9 @@ def delete(self, key):

def idle_callback(self):
"""Call AllocateIds on any incomplete keys in the batch."""
if not self.incomplete_mutations:
# This will happen if `commit` is called first.
# If there are no incomplete mutations, or if we're already preparing
# to commit, there's no need to allocate ids.
if self.preparing_to_commit or not self.incomplete_mutations:
return

# Signal to a future commit that there is an id allocation in
@@ -728,11 +743,6 @@ def commit(self, retries=None, timeout=None):
if not future.done():
yield future

# Head off making any more AllocateId calls. Any remaining incomplete
# keys will get ids as part of the Commit call.
self.incomplete_mutations = []
self.incomplete_futures = []

future = tasklets.Future("Commit")
futures = self.futures

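
A rough sketch of what the new preparing_to_commit flag buys: once prepare_to_commit has marked the batch, the idle callback skips the AllocateIds round trip and lets Commit assign the remaining ids. The CommitBatch class below is a hypothetical, stripped-down stand-in for the commit batch class changed above and performs no RPCs.

class CommitBatch:
    """Hypothetical stand-in for the commit batch; no RPCs are made."""

    def __init__(self):
        self.incomplete_mutations = ["entity-without-a-key-id"]
        self.preparing_to_commit = False

    def idle_callback(self):
        # Same guard as in the diff: skip AllocateIds when a commit is
        # imminent or when there is nothing incomplete to allocate ids for.
        if self.preparing_to_commit or not self.incomplete_mutations:
            return
        print("would call AllocateIds for", self.incomplete_mutations)


batch = CommitBatch()
batch.idle_callback()              # would call AllocateIds
batch.preparing_to_commit = True   # roughly what prepare_to_commit() does
batch.idle_callback()              # skipped; Commit will assign the ids
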
9 changes: 1 addition & 8 deletions google/cloud/ndb/_eventloop.py
@@ -34,7 +34,6 @@
"queue_call",
"queue_rpc",
"run",
"run0",
"run1",
]

@@ -396,13 +395,7 @@ def run():
loop.run()


def run0():
[Inline review comment from the PR author: Wasn't used.]

"""Calls :method:`EventLoop.run0` on current event loop."""
loop = get_event_loop()
loop.run0()


def run1():
"""Calls :method:`EventLoop.run1` on current event loop."""
loop = get_event_loop()
loop.run1()
return loop.run1()
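
Note that the module-level run1() now propagates the event loop's return value, which callers use to tell whether the loop still had work to run. A toy illustration of that contract follows; the EventLoop here is a stand-in with made-up internals, not the real one.

class EventLoop:
    """Toy loop whose run1() reports whether it still had work to run."""

    def __init__(self, work):
        self.work = list(work)

    def run1(self):
        if not self.work:
            return False
        print("running", self.work.pop(0))
        return True


loop = EventLoop(["rpc-1", "callback-2"])
while loop.run1():
    pass  # keep turning the loop until it reports nothing is left
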
23 changes: 22 additions & 1 deletion google/cloud/ndb/_transaction.py
@@ -130,15 +130,36 @@ def _transaction_async(context, callback, read_only=False):
tx_context = context.new(
transaction=transaction_id,
on_commit_callbacks=on_commit_callbacks,
cache=None, # Use new, empty cache for transaction
batches=None,
commit_batches=None,
cache=None,
# We could just pass `None` here and let the `Context` constructor
# instantiate a new event loop, but our unit tests inject a subclass of
# `EventLoop` that makes testing a little easier. This makes sure the
# new event loop is of the same type as the current one, to propagate
# the event loop class used for testing.
eventloop=type(context.eventloop)(),
)

# The outer loop is dependent on the inner loop
def run_inner_loop(inner_context):
with inner_context.use():
if inner_context.eventloop.run1():
return True # schedule again

context.eventloop.add_idle(run_inner_loop, tx_context)

with tx_context.use():
try:
# Run the callback
result = callback()
if isinstance(result, tasklets.Future):
result = yield result

# Make sure we've run everything we can run before calling commit
_datastore_api.prepare_to_commit(transaction_id)
tx_context.eventloop.run()

# Commit the transaction
yield _datastore_api.commit(transaction_id, retries=0)

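
The heart of the refactor is above: each transaction gets its own event loop, and the outer loop drains it through an idle callback that asks to be rescheduled while the inner loop still has work. Below is a minimal, self-contained sketch of that arrangement; the EventLoop class is a toy with made-up internals, standing in for context.eventloop and tx_context.eventloop.

class EventLoop:
    """Toy loop: queued work plus idle callbacks, with made-up internals."""

    def __init__(self, name, work=()):
        self.name = name
        self.work = list(work)
        self.idlers = []

    def add_idle(self, callback, *args):
        self.idlers.append((callback, args))

    def run1(self):
        if self.work:
            print(self.name, "ran", self.work.pop(0))
            return True
        return False

    def run_idle(self):
        ran = False
        for idler in list(self.idlers):
            callback, args = idler
            if callback(*args):  # truthy return means "schedule me again"
                ran = True
            else:
                self.idlers.remove(idler)
        return ran

    def run(self):
        while self.run1() or self.run_idle():
            pass


outer = EventLoop("outer")                            # context.eventloop
inner = EventLoop("inner", ["tx-lookup", "tx-put"])   # tx_context.eventloop


def run_inner_loop():
    # Same shape as run_inner_loop in the diff: give the transaction's loop
    # one turn, and ask to be rescheduled while it still has work.
    return inner.run1()


outer.add_idle(run_inner_loop)
outer.run()  # the outer loop drains the inner loop during its idle time
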
20 changes: 15 additions & 5 deletions google/cloud/ndb/tasklets.py
@@ -123,7 +123,10 @@ def wait(self):
after a call to this method.
"""
while not self._done:
_eventloop.run1()
if not _eventloop.run1():
[Inline review comment from the PR author: Sanity check.]

raise RuntimeError(
"Eventloop is exhausted with unfinished futures."
)

def check_success(self):
"""Check whether a future has completed without raising an exception.
@@ -348,16 +351,20 @@ def done_callback(yielded):

error = yielded.exception()
if error:
_eventloop.call_soon(self._advance_tasklet, error=error)
self.context.eventloop.call_soon(
self._advance_tasklet, error=error
)
else:
_eventloop.call_soon(self._advance_tasklet, yielded.result())
self.context.eventloop.call_soon(
self._advance_tasklet, yielded.result()
)

if isinstance(yielded, Future):
yielded.add_done_callback(done_callback)
self.waiting_on = yielded

elif isinstance(yielded, _remote.RemoteCall):
_eventloop.queue_rpc(yielded, done_callback)
self.context.eventloop.queue_rpc(yielded, done_callback)
self.waiting_on = yielded

elif isinstance(yielded, (list, tuple)):
@@ -515,7 +522,10 @@ def wait_any(futures):
if future.done():
return future

_eventloop.run1()
if not _eventloop.run1():
raise RuntimeError(
"Eventloop is exhausted with unfinished futures."
)


def wait_all(futures):
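
The "Sanity check" comment above refers to the new guard in Future.wait and wait_any: if run1() reports that the loop has nothing left to run while futures are still pending, raising immediately beats spinning forever. A small self-contained sketch of that guard, with stand-in classes instead of the real tasklets and event loop:

class ExhaustedLoop:
    """Stand-in loop that has nothing left to run."""

    def run1(self):
        return False


class PendingFuture:
    """Stand-in future that will never complete."""

    _done = False

    def wait(self, loop):
        while not self._done:
            if not loop.run1():
                raise RuntimeError(
                    "Eventloop is exhausted with unfinished futures."
                )


try:
    PendingFuture().wait(ExhaustedLoop())
except RuntimeError as error:
    print(error)  # fails fast instead of spinning forever
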
22 changes: 22 additions & 0 deletions tests/system/__init__.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import operator
import time

KIND = "SomeKind"
@@ -61,3 +63,23 @@ def eventually(f, predicate, timeout=120, interval=2):
time.sleep(interval)

assert predicate(value)


def length_equals(n):
"""Returns predicate that returns True if passed a sequence of length `n`.

For use with `eventually`.
"""

def predicate(sequence):
return len(sequence) == n

return predicate


def equals(n):
"""Returns predicate that returns True if passed `n`.

For use with `eventually`.
"""
return functools.partial(operator.eq, n)
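
To show how the new helpers are meant to be combined with eventually, here is a self-contained sketch. The simplified eventually below only mirrors the shape of the helper in this module, and the toy fetch function and results list stand in for a real Datastore query.

import functools
import operator
import time


def eventually(f, predicate, timeout=10, interval=0.5):
    # Simplified copy of the helper in this module: poll f() until the
    # predicate accepts its value or the timeout expires.
    deadline = time.time() + timeout
    value = f()
    while not predicate(value) and time.time() < deadline:
        time.sleep(interval)
        value = f()
    assert predicate(value)


def length_equals(n):
    return lambda sequence: len(sequence) == n


def equals(n):
    return functools.partial(operator.eq, n)


results = []  # toy "Datastore" that gains one entity per poll


def fetch():
    results.append("entity")
    return list(results)


eventually(fetch, length_equals(3))  # waits until three entities are visible
assert equals(3)(len(results))
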
33 changes: 25 additions & 8 deletions tests/system/test_crud.py
@@ -16,8 +16,6 @@
System tests for Create, Update, Delete. (CRUD)
"""
import datetime
import functools
import operator
import os
import pickle
import random
@@ -37,15 +35,11 @@
from google.cloud.ndb import _cache
from google.cloud.ndb import global_cache as global_cache_module

from tests.system import KIND, eventually
from tests.system import KIND, eventually, equals

USE_REDIS_CACHE = bool(os.environ.get("REDIS_CACHE_URL"))


def _equals(n):
return functools.partial(operator.eq, n)


@pytest.mark.usefixtures("client_context")
def test_retrieve_entity(ds_entity):
entity_id = test_utils.system.unique_resource_id()
@@ -526,7 +520,7 @@ class SomeKind(ndb.Model):
# Sneaky. Delete entity out from under cache so we know we're getting
# cached copy.
key.delete()
eventually(key.get, _equals(None))
eventually(key.get, equals(None))

retrieved = key.get()
assert retrieved.foo == 42
@@ -772,6 +766,29 @@ def delete_entity():
assert key.get() is None


def test_delete_entity_in_transaction_with_global_cache(
client_context, ds_entity
):
"""Regression test for #426

https://github.com/googleapis/python-ndb/issues/426
"""

class SomeKind(ndb.Model):
foo = ndb.IntegerProperty()

entity_id = test_utils.system.unique_resource_id()
ds_entity(KIND, entity_id, foo=42)

global_cache = global_cache_module._InProcessGlobalCache()
with client_context.new(global_cache=global_cache).use():
key = ndb.Key(KIND, entity_id)
assert key.get().foo == 42

ndb.transaction(key.delete)
assert key.get() is None


@pytest.mark.usefixtures("client_context")
def test_delete_entity_in_transaction_then_rollback(ds_entity):
entity_id = test_utils.system.unique_resource_id()
24 changes: 24 additions & 0 deletions tests/system/test_misc.py
@@ -22,6 +22,8 @@

from google.cloud import ndb

from tests.system import eventually, length_equals

USE_REDIS_CACHE = bool(os.environ.get("REDIS_CACHE_URL"))


@@ -271,3 +273,25 @@ def update(id, add, fail=False):

entity = SomeKind.get_by_id(id)
assert entity.foo == 142


@pytest.mark.usefixtures("client_context")
def test_insert_entity_in_transaction_without_preallocating_id(dispose_of):
class SomeKind(ndb.Model):
foo = ndb.IntegerProperty()
bar = ndb.StringProperty()

def save_entity():
# By not waiting on the Future, we don't force a call to AllocateIds
# before the transaction is committed.
SomeKind(foo=42, bar="none").put_async()

ndb.transaction(save_entity)

query = SomeKind.query()
eventually(query.fetch, length_equals(1))
retrieved = query.fetch()[0]
dispose_of(retrieved._key._key)

assert retrieved.foo == 42
assert retrieved.bar == "none"