Skip to content

Commit

Permalink
feat: add redis caching (#179)
Browse files Browse the repository at this point in the history
Signed-off-by: s0nicboOm <i.kushalbatra@gmail.com>
  • Loading branch information
s0nicboOm authored and ab93 committed May 11, 2023
1 parent 7834730 commit e884b90
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 57 deletions.
10 changes: 8 additions & 2 deletions numalogic/registry/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,10 @@ def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None:
def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool:
    """
    Returns whether the given artifact is stale or not, i.e. if
    more time has elapsed since it was last retrained.

    Args:
        artifact_data: ArtifactData object to look into
        freq_hr: Frequency of retraining in hours

    Raises:
        NotImplementedError: always; concrete registries must override this.
    """
    # Abstract hook: each registry backend decides how to compare the
    # artifact's stored timestamp against the retraining frequency.
    raise NotImplementedError("Please implement this method!")

Expand Down Expand Up @@ -147,3 +146,10 @@ def delete(self, key: str) -> None:
Implement this method for your custom cache.
"""
raise NotImplementedError("Please implement this method!")

def clear(self) -> None:
    """
    Flush every entry from the cache in one call.

    Concrete cache implementations must override this method;
    the base class only defines the contract.
    """
    raise NotImplementedError("Please implement this method!")
3 changes: 3 additions & 0 deletions numalogic/registry/localcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,6 @@ def save(self, key: str, artifact: ArtifactData) -> None:

def delete(self, key: str) -> Optional[ArtifactData]:
    """Remove the artifact stored under ``key`` and return it; None if absent."""
    # NOTE(review): passing ``default`` as a keyword works for
    # MutableMapping-based caches (e.g. cachetools) but would raise
    # TypeError on a plain dict — confirm the underlying cache type.
    return self.__cache.pop(key, default=None)

def clear(self) -> None:
    """Drop every entry from the in-memory cache."""
    return self.__cache.clear()
117 changes: 74 additions & 43 deletions numalogic/registry/redis_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from redis.exceptions import RedisError

from numalogic.registry import ArtifactManager, ArtifactData
from numalogic.registry import ArtifactManager, ArtifactData, ArtifactCache
from numalogic.registry._serialize import loads, dumps
from numalogic.tools.exceptions import ModelKeyNotFound, RedisRegistryError
from numalogic.tools.types import artifact_t, redis_client_t, KEYS, META_T, META_VT
Expand All @@ -19,6 +19,7 @@ class RedisRegistry(ArtifactManager):
Args:
client: Take in the redis client already established/created
ttl: Total Time to Live (in seconds) for the key when saving in redis (default = 604800)
cache_registry: Cache registry to use (default = None)
Examples
--------
Expand All @@ -34,16 +35,18 @@ class RedisRegistry(ArtifactManager):
>>> loaded_artifact = registry.load(skeys, dkeys)
"""

__slots__ = ("client", "ttl")
__slots__ = ("client", "ttl", "cache_registry")

def __init__(
    self,
    client: redis_client_t,
    ttl: int = 604800,
    cache_registry: Optional[ArtifactCache] = None,
):
    """
    Initialize the redis registry.

    Args:
        client: pre-established redis client to use for all registry calls
        ttl: time-to-live in seconds applied to saved keys
            (default 604800, i.e. 7 days)
        cache_registry: optional local artifact cache consulted before
            hitting redis (default None disables caching)
    """
    # PEP 484: a default of None requires an explicit Optional annotation.
    super().__init__("")
    self.client = client
    self.ttl = ttl
    self.cache_registry = cache_registry

@staticmethod
def construct_key(skeys: KEYS, dkeys: KEYS) -> str:
Expand Down Expand Up @@ -81,24 +84,65 @@ def get_version(key: str) -> str:
"""
return key.split("::")[-1]

def __get_model_key(self, latest: bool, version: str, key: str) -> str:
    """
    Resolve the concrete redis key that holds the model hash.

    Args:
        latest: when True, follow the production pointer to the newest model
        version: explicit model version to look up (used when latest is False)
        key: base registry key built from skeys/dkeys

    Raises:
        ModelKeyNotFound: if the production pointer or the versioned key is absent
    """
    if latest:
        production_key = self.__construct_production_key(key)
        if not self.client.exists(production_key):
            raise ModelKeyNotFound(
                f"Production key: {production_key}, Not Found !!!\n Exiting....."
            )
        # The production key only stores a pointer; the target hash may have
        # been deleted independently, so verify it still exists.
        model_key = self.client.get(production_key)
        if not self.client.exists(model_key):
            raise ModelKeyNotFound(
                "Production key = {} is pointing to the key: {} that "
                "is missing the redis registry".format(production_key, model_key)
            )
    else:
        model_key = self.__construct_version_key(key, version)
        if not self.client.exists(model_key):
            raise ModelKeyNotFound("Could not find model key with key: %s" % model_key)
    return model_key
def _load_from_cache(self, key: str) -> Optional[ArtifactData]:
    """Return the cached ArtifactData for ``key``; None when caching is
    disabled or the key is not cached."""
    if self.cache_registry:
        return self.cache_registry.load(key)
    return None

def _save_in_cache(self, key: str, artifact_data: ArtifactData) -> None:
    """Store ``artifact_data`` under ``key`` when a cache registry is
    configured; silently no-op otherwise."""
    if not self.cache_registry:
        return
    self.cache_registry.save(key, artifact_data)

def _clear_cache(self, key: Optional[str] = None) -> Optional[ArtifactData]:
    """Evict ``key`` from the cache, or flush the whole cache when no key
    is given. Returns the evicted entry (if any) or None."""
    if not self.cache_registry:
        return None
    if key:
        # Targeted eviction: delete() hands back the removed entry.
        return self.cache_registry.delete(key)
    return self.cache_registry.clear()

def __get_artifact_data(
    self,
    model_key: str,
) -> ArtifactData:
    """Fetch the hash stored at ``model_key`` and deserialize it into an
    ArtifactData (artifact, optional metadata, version/timestamp extras)."""
    raw_artifact, raw_version, raw_timestamp, raw_metadata = self.client.hmget(
        name=model_key, keys=["artifact", "version", "timestamp", "metadata"]
    )
    artifact = loads(raw_artifact)
    # Metadata is optional in the stored hash; deserialize only when present.
    metadata = loads(raw_metadata) if raw_metadata else None
    extras = {
        "timestamp": float(raw_timestamp.decode()),
        "version": raw_version.decode(),
    }
    return ArtifactData(artifact=artifact, metadata=metadata, extras=extras)

def __load_latest_artifact(self, key: str) -> ArtifactData:
    """Load the artifact the production pointer for ``key`` refers to,
    consulting the local cache first.

    Raises:
        ModelKeyNotFound: if the production pointer does not exist.
    """
    cached = self._load_from_cache(key)
    if cached:
        return cached
    production_key = self.__construct_production_key(key)
    if not self.client.exists(production_key):
        raise ModelKeyNotFound(
            f"Production key: {production_key}, Not Found !!!\n Exiting....."
        )
    model_key = self.client.get(production_key)
    _LOGGER.info("Production key, %s, is pointing to the key : %s", production_key, model_key)
    # Delegate to the versioned loader using the version embedded in the key.
    version = self.get_version(model_key.decode())
    return self.__load_version_artifact(version=version, key=key)

def __load_version_artifact(self, version: str, key: str) -> ArtifactData:
    """Load a specific ``version`` of the artifact stored under ``key``
    directly from redis (no cache involvement).

    Raises:
        ModelKeyNotFound: if no hash exists for that version.
    """
    model_key = self.__construct_version_key(key, version)
    if self.client.exists(model_key):
        return self.__get_artifact_data(model_key=model_key)
    raise ModelKeyNotFound("Could not find model key with key: %s" % model_key)

def __save_artifact(
self, pipe, artifact: artifact_t, metadata: META_T, key: KEYS, version: str
Expand All @@ -118,7 +162,7 @@ def __save_artifact(
mapping={
"artifact": serialized_artifact,
"version": str(version),
"timestamp": int(time.time()),
"timestamp": time.time(),
"metadata": serialized_metadata,
},
)
Expand Down Expand Up @@ -147,30 +191,15 @@ def load(
raise ValueError("Either One of 'latest' or 'version' needed in load method call")
key = self.construct_key(skeys, dkeys)
try:
model_key = self.__get_model_key(latest, version, key)
(
serialized_artifact,
artifact_version,
artifact_timestamp,
serialized_metadata,
) = self.client.hmget(
name=model_key, keys=["artifact", "version", "timestamp", "metadata"]
)
deserialized_artifact = loads(serialized_artifact)
deserialized_metadata = None
if serialized_metadata:
deserialized_metadata = loads(serialized_metadata)
if latest:
artifact_data = self.__load_latest_artifact(key)
self._save_in_cache(key, artifact_data)
else:
artifact_data = self.__load_version_artifact(version, key)
except RedisError as err:
raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
else:
return ArtifactData(
artifact=deserialized_artifact,
metadata=deserialized_metadata,
extras={
"timestamp": artifact_timestamp.decode(),
"version": artifact_version.decode(),
},
)
return artifact_data

def save(
self,
Expand Down Expand Up @@ -229,12 +258,14 @@ def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None:
)
except RedisError as err:
raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
else:
self._clear_cache(del_key)

@staticmethod
def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool:
"""
Returns whether the given artifact is stale or not, i.e. if
more time has elapsed since it was last retrained.
more time has elapsed since it was last retrained.
Args:
artifact_data: ArtifactData object to look into
freq_hr: Frequency of retraining in hours
Expand Down
8 changes: 8 additions & 0 deletions tests/registry/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def test_cache(self):
cache_reg.load("m1")
with self.assertRaises(NotImplementedError):
cache_reg.delete("m1")
with self.assertRaises(NotImplementedError):
cache_reg.clear()


class TestLocalLRUCache(unittest.TestCase):
Expand Down Expand Up @@ -63,6 +65,12 @@ def test_delete(self):
cache_registry.delete("m1")
self.assertIsNone(cache_registry.load("m1"))

def test_clear(self):
    """clear() must evict previously-saved entries."""
    cache = LocalLRUCache(cachesize=2, ttl=1)
    artifact = ArtifactData(VanillaAE(10, 1), metadata={}, extras={})
    cache.save("m1", artifact)
    cache.clear()
    self.assertIsNone(cache.load("m1"))


if __name__ == "__main__":
unittest.main()
13 changes: 6 additions & 7 deletions tests/registry/test_mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def test_save_model(self):
skeys = self.skeys
dkeys = self.dkeys
status = ml.save(skeys=skeys, dkeys=dkeys, artifact=self.model, run_id="1234")
print(status)
mock_status = "READY"
self.assertEqual(mock_status, status.status)

Expand Down Expand Up @@ -209,7 +208,7 @@ def test_load_model_when_no_model_02(self):
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch(
"mlflow.tracking.MlflowClient.transition_model_version_stage",
Mock(side_effect=RestException({"error_code": ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)})),
Mock(side_effect=RestException({"error_code": ErrorCode.Name(RESOURCE_LIMIT_EXCEEDED)})),
)
def test_transition_stage_fail(self):
fake_skeys = ["Fakemodel_"]
Expand Down Expand Up @@ -250,7 +249,6 @@ def test_delete_model_when_no_model(self):
ml = MLflowRegistry(TRACKING_URI)
with self.assertLogs(level="ERROR") as log:
ml.delete(skeys=fake_skeys, dkeys=fake_dkeys, version="1")
print(log.output)
self.assertTrue(log.output)

@patch("mlflow.pytorch.log_model", Mock(side_effect=RuntimeError))
Expand All @@ -268,21 +266,22 @@ def test_save_failed(self):
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch(
"mlflow.pytorch.load_model",
"mlflow.tracking.MlflowClient.get_latest_versions",
Mock(side_effect=RestException({"error_code": ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)})),
)
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict()))
def test_load_no_model_found(self):
ml = MLflowRegistry(TRACKING_URI, artifact_type="pytorch")
skeys = self.skeys
dkeys = self.dkeys
self.assertIsNone(ml.load(skeys=skeys, dkeys=dkeys))
data = ml.load(skeys=skeys, dkeys=dkeys)
self.assertIsNone(data)

@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch(
"mlflow.pytorch.load_model",
Mock(side_effect=RestException({"error_code": ErrorCode.Name(RESOURCE_LIMIT_EXCEEDED)})),
"mlflow.tracking.MlflowClient.get_latest_versions",
Mock(side_effect=RestException({"error_code": ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)})),
)
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict()))
def test_load_other_mlflow_err(self):
Expand Down
67 changes: 62 additions & 5 deletions tests/registry/test_redis_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sklearn.preprocessing import StandardScaler

from numalogic.models.autoencoder.variants import VanillaAE
from numalogic.registry import RedisRegistry
from numalogic.registry import RedisRegistry, LocalLRUCache, ArtifactData
from numalogic.tools.exceptions import ModelKeyNotFound, RedisRegistryError


Expand All @@ -24,24 +24,73 @@ def setUpClass(cls) -> None:
cls.redis_client = fakeredis.FakeStrictRedis(server=server, decode_responses=False)

def setUp(self):
    """Build one cache-backed registry and one cache-less registry per test."""
    # Reconstructed from the added side of the diff: the pre-cache
    # single-registry setup line was interleaved into this span.
    self.cache = LocalLRUCache(cachesize=4, ttl=300)
    self.registry = RedisRegistry(
        client=self.redis_client,
        cache_registry=self.cache,
    )
    self.registry_no_cache = RedisRegistry(client=self.redis_client)

def tearDown(self) -> None:
    """Wipe redis and the local cache so tests stay independent."""
    for registry in (self.registry, self.registry_no_cache):
        registry.client.flushall()
    self.cache.clear()

def test_no_cache(self):
    """Every cache helper is a harmless no-op when no cache is configured."""
    registry = self.registry_no_cache
    artifact = ArtifactData(artifact=self.pytorch_model, extras={}, metadata={})
    self.assertIsNone(registry._save_in_cache("key", artifact))
    self.assertIsNone(registry._load_from_cache("key"))
    self.assertIsNone(registry._clear_cache("key"))

def test_construct_key(self):
    """skeys join with ':' and are separated from dkeys by '::'."""
    self.assertEqual(
        "model_:nnet::error1",
        RedisRegistry.construct_key(["model_", "nnet"], ["error1"]),
    )

def test_save_model_without_metadata(self):
def test_save_model_without_metadata_cache_hit(self):
    """A second load hits the cache, so it still sees the first version
    even after a newer version was saved to redis.

    Reconstructed from the added side of the diff: two stale removed
    lines were interleaved into this span.
    """
    first_version = self.registry.save(
        skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model
    )
    first_load = self.registry.load(skeys=self.skeys, dkeys=self.dkeys)
    self.assertEqual(first_load.extras["version"], first_version)
    second_version = self.registry.save(
        skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model
    )
    # Cache hit: the reload returns the cached (older) artifact.
    second_load = self.registry.load(skeys=self.skeys, dkeys=self.dkeys)
    self.assertEqual(first_version, "0")
    self.assertEqual(second_version, "1")
    self.assertEqual(second_load.extras["version"], "0")

def test_save_load_without_cache(self):
    """Without a cache, every load reflects the newest version in redis."""
    registry = self.registry_no_cache
    first_version = registry.save(
        skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model
    )
    first_load = registry.load(skeys=self.skeys, dkeys=self.dkeys)
    self.assertEqual(first_load.extras["version"], first_version)
    second_version = registry.save(
        skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model
    )
    second_load = registry.load(skeys=self.skeys, dkeys=self.dkeys)
    self.assertEqual(first_version, "0")
    self.assertEqual(second_version, "1")
    self.assertEqual(second_load.extras["version"], "1")

def test_save_model_without_metadata_cache_miss(self):
    """After the cache is cleared, load falls through to redis and picks
    up the newly saved version."""
    first_version = self.registry.save(
        skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model
    )
    first_load = self.registry.load(skeys=self.skeys, dkeys=self.dkeys)
    self.assertEqual(first_load.extras["version"], first_version)
    second_version = self.registry.save(
        skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model
    )
    # Force a cache miss so the next load must read redis again.
    self.cache.clear()
    fresh_load = self.registry.load(skeys=self.skeys, dkeys=self.dkeys)
    self.assertEqual(first_version, "0")
    self.assertEqual(second_version, "1")
    self.assertEqual(fresh_load.extras["version"], "1")

def test_load_model_without_metadata(self):
version = self.registry.save(
Expand All @@ -61,6 +110,14 @@ def test_load_model_with_metadata(self):
self.assertIsNotNone(data.metadata)
self.assertEqual(data.extras["version"], version)

def test_delete_model(self):
    """Deleting the only saved version makes subsequent loads fail."""
    saved_version = self.registry.save(
        skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model
    )
    self.registry.delete(skeys=self.skeys, dkeys=self.dkeys, version=saved_version)
    with self.assertRaises(ModelKeyNotFound):
        self.registry.load(skeys=self.skeys, dkeys=self.dkeys)

def test_load_model_with_version(self):
version = self.registry.save(
skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model
Expand Down

0 comments on commit e884b90

Please sign in to comment.