Skip to content

Commit

Permalink
fix: fix a connection leak in RedisCache (#556)
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Rossi authored Oct 8, 2020
1 parent 758c8e6 commit 47ae172
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 5 deletions.
32 changes: 31 additions & 1 deletion google/cloud/ndb/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,37 @@ def make_call(self):

def future_info(self, key):
"""Generate info string for Future."""
return "GlobalWatch.delete({})".format(key)
return "GlobalCache.watch({})".format(key)


def global_unwatch(key):
"""End optimistic transaction with global cache.
Indicates that value for the key wasn't found in the database, so there will not be
a future call to :func:`global_compare_and_swap`, and we no longer need to watch
this key.
Args:
key (bytes): The key to unwatch.
Returns:
tasklets.Future: Eventual result will be ``None``.
"""
batch = _batch.get_batch(_GlobalCacheUnwatchBatch)
return batch.add(key)


class _GlobalCacheUnwatchBatch(_GlobalCacheWatchBatch):
"""Batch for global cache unwatch requests. """

def make_call(self):
"""Call :method:`GlobalCache.unwatch`."""
cache = context_module.get_context().global_cache
return cache.unwatch(self.keys)

def future_info(self, key):
"""Generate info string for Future."""
return "GlobalCache.unwatch({})".format(key)


def global_compare_and_swap(key, value, expires=None):
Expand Down
13 changes: 9 additions & 4 deletions google/cloud/ndb/_datastore_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,15 @@ def lookup(key, options):
entity_pb = yield batch.add(key)

# Do not cache misses
if use_global_cache and not key_locked and entity_pb is not _NOT_FOUND:
expires = context._global_cache_timeout(key, options)
serialized = entity_pb.SerializeToString()
yield _cache.global_compare_and_swap(cache_key, serialized, expires=expires)
if use_global_cache and not key_locked:
if entity_pb is not _NOT_FOUND:
expires = context._global_cache_timeout(key, options)
serialized = entity_pb.SerializeToString()
yield _cache.global_compare_and_swap(
cache_key, serialized, expires=expires
)
else:
yield _cache.global_unwatch(cache_key)

raise tasklets.Return(entity_pb)

Expand Down
32 changes: 32 additions & 0 deletions google/cloud/ndb/global_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,19 @@ def watch(self, keys):
"""
raise NotImplementedError

@abc.abstractmethod
def unwatch(self, keys):
"""End an optimistic transaction for the given keys.
Indicates that value for the key wasn't found in the database, so there will not
be a future call to :meth:`compare_and_swap`, and we no longer need to watch
this key.
Arguments:
keys (List[bytes]): The keys to watch.
"""
raise NotImplementedError

@abc.abstractmethod
def compare_and_swap(self, items, expires=None):
"""Like :meth:`set` but using an optimistic transaction.
Expand Down Expand Up @@ -160,6 +173,11 @@ def watch(self, keys):
for key in keys:
self._watch_keys[key] = self.cache.get(key)

def unwatch(self, keys):
"""Implements :meth:`GlobalCache.unwatch`."""
for key in keys:
self._watch_keys.pop(key, None)

def compare_and_swap(self, items, expires=None):
"""Implements :meth:`GlobalCache.compare_and_swap`."""
if expires:
Expand Down Expand Up @@ -239,6 +257,13 @@ def watch(self, keys):
for key in keys:
self.pipes[key] = holder

def unwatch(self, keys):
"""Implements :meth:`GlobalCache.watch`."""
for key in keys:
holder = self.pipes.pop(key, None)
if holder:
holder.pipe.reset()

def compare_and_swap(self, items, expires=None):
"""Implements :meth:`GlobalCache.compare_and_swap`."""
pipes = {}
Expand Down Expand Up @@ -391,6 +416,13 @@ def watch(self, keys):
for key, (value, caskey) in self.client.gets_many(keys).items():
caskeys[key] = caskey

def unwatch(self, keys):
"""Implements :meth:`GlobalCache.unwatch`."""
keys = [self._key(key) for key in keys]
caskeys = self.caskeys
for key in keys:
caskeys.pop(key, None)

def compare_and_swap(self, items, expires=None):
"""Implements :meth:`GlobalCache.compare_and_swap`."""
caskeys = self.caskeys
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/test__cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,31 @@ def test_add_and_idle_and_done_callbacks(in_context):
assert future2.result() is None


@mock.patch("google.cloud.ndb._cache._batch")
def test_global_unwatch(_batch):
batch = _batch.get_batch.return_value
assert _cache.global_unwatch(b"key") is batch.add.return_value
_batch.get_batch.assert_called_once_with(_cache._GlobalCacheUnwatchBatch)
batch.add.assert_called_once_with(b"key")


class Test_GlobalCacheUnwatchBatch:
@staticmethod
def test_add_and_idle_and_done_callbacks(in_context):
cache = mock.Mock()

batch = _cache._GlobalCacheUnwatchBatch({})
future1 = batch.add(b"foo")
future2 = batch.add(b"bar")

with in_context.new(global_cache=cache).use():
batch.idle_callback()

cache.unwatch.assert_called_once_with([b"foo", b"bar"])
assert future1.result() is None
assert future2.result() is None


class Test_global_compare_and_swap:
@staticmethod
@mock.patch("google.cloud.ndb._cache._batch")
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test__datastore_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ class SomeKind(model.Model):
assert future.result() is _api._NOT_FOUND

assert global_cache.get([cache_key]) == [_cache._LOCKED]
assert len(global_cache._watch_keys) == 0


class Test_LookupBatch:
Expand Down
46 changes: 46 additions & 0 deletions tests/unit/test_global_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def delete(self, keys):
def watch(self, keys):
return super(MockImpl, self).watch(keys)

def unwatch(self, keys):
return super(MockImpl, self).unwatch(keys)

def compare_and_swap(self, items, expires=None):
return super(MockImpl, self).compare_and_swap(items, expires=expires)

Expand All @@ -63,6 +66,11 @@ def test_watch(self):
with pytest.raises(NotImplementedError):
cache.watch(b"foo")

def test_unwatch(self):
cache = self.make_one()
with pytest.raises(NotImplementedError):
cache.unwatch(b"foo")

def test_compare_and_swap(self):
cache = self.make_one()
with pytest.raises(NotImplementedError):
Expand Down Expand Up @@ -147,6 +155,16 @@ def test_watch_compare_and_swap_with_expires(time):
result = cache.get([b"one", b"two", b"three"])
assert result == [None, b"hamburgers", None]

@staticmethod
def test_watch_unwatch():
cache = global_cache._InProcessGlobalCache()
result = cache.watch([b"one", b"two", b"three"])
assert result is None

result = cache.unwatch([b"one", b"two", b"three"])
assert result is None
assert cache._watch_keys == {}


class TestRedisCache:
@staticmethod
Expand Down Expand Up @@ -225,6 +243,23 @@ def test_watch(uuid):
"bar": global_cache._Pipeline(pipe, "abc123"),
}

@staticmethod
def test_unwatch():
redis = mock.Mock(spec=())
cache = global_cache.RedisCache(redis)
pipe1 = mock.Mock(spec=("reset",))
pipe2 = mock.Mock(spec=("reset",))
cache._pipes.pipes = {
"ay": global_cache._Pipeline(pipe1, "abc123"),
"be": global_cache._Pipeline(pipe1, "abc123"),
"see": global_cache._Pipeline(pipe2, "def456"),
"dee": global_cache._Pipeline(pipe2, "def456"),
"whatevs": global_cache._Pipeline(None, "himom!"),
}

cache.unwatch(["ay", "be", "see", "dee", "nuffin"])
assert cache.pipes == {"whatevs": global_cache._Pipeline(None, "himom!")}

@staticmethod
def test_compare_and_swap():
redis = mock.Mock(spec=())
Expand Down Expand Up @@ -450,6 +485,17 @@ def test_watch():
key2: b"1",
}

@staticmethod
def test_unwatch():
client = mock.Mock(spec=())
cache = global_cache.MemcacheCache(client)
key2 = cache._key(b"two")
cache.caskeys[key2] = b"5"
cache.caskeys["whatevs"] = b"6"
cache.unwatch([b"one", b"two"])

assert cache.caskeys == {"whatevs": b"6"}

@staticmethod
def test_compare_and_swap():
client = mock.Mock(spec=("cas",))
Expand Down

0 comments on commit 47ae172

Please sign in to comment.