diff --git a/numalogic/registry/artifact.py b/numalogic/registry/artifact.py index 05bbd6de..e3815dd8 100644 --- a/numalogic/registry/artifact.py +++ b/numalogic/registry/artifact.py @@ -79,11 +79,10 @@ def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None: def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool: """ Returns whether the given artifact is stale or not, i.e. if - more time has elasped since it was last retrained. + more time has elapsed since it was last retrained. Args: artifact_data: ArtifactData object to look into freq_hr: Frequency of retraining in hours - """ raise NotImplementedError("Please implement this method!") @@ -147,3 +146,10 @@ def delete(self, key: str) -> None: Implement this method for your custom cache. """ raise NotImplementedError("Please implement this method!") + + def clear(self) -> None: + r""" + Clears the cache. + Implement this method for your custom cache. + """ + raise NotImplementedError("Please implement this method!") diff --git a/numalogic/registry/localcache.py b/numalogic/registry/localcache.py index f7c9ff0f..1980533b 100644 --- a/numalogic/registry/localcache.py +++ b/numalogic/registry/localcache.py @@ -41,3 +41,6 @@ def save(self, key: str, artifact: ArtifactData) -> None: def delete(self, key: str) -> Optional[ArtifactData]: return self.__cache.pop(key, default=None) + + def clear(self) -> None: + self.__cache.clear() diff --git a/numalogic/registry/redis_registry.py b/numalogic/registry/redis_registry.py index 851807be..880c0fd1 100644 --- a/numalogic/registry/redis_registry.py +++ b/numalogic/registry/redis_registry.py @@ -5,7 +5,7 @@ from redis.exceptions import RedisError -from numalogic.registry import ArtifactManager, ArtifactData +from numalogic.registry import ArtifactManager, ArtifactData, ArtifactCache from numalogic.registry._serialize import loads, dumps from numalogic.tools.exceptions import ModelKeyNotFound, RedisRegistryError from numalogic.tools.types import artifact_t, redis_client_t, KEYS, META_T, META_VT @@ -19,6 +19,7 @@ class RedisRegistry(ArtifactManager): Args: client: Take in the reids client already established/created ttl: Total Time to Live (in seconds) for the key when saving in redis (dafault = 604800) + cache_registry: Cache registry to use (default = None) Examples -------- @@ -34,16 +35,18 @@ class RedisRegistry(ArtifactManager): >>> loaded_artifact = registry.load(skeys, dkeys) """ - __slots__ = ("client", "ttl") + __slots__ = ("client", "ttl", "cache_registry") def __init__( self, client: redis_client_t, ttl: int = 604800, + cache_registry: ArtifactCache = None, ): super().__init__("") self.client = client self.ttl = ttl + self.cache_registry = cache_registry @staticmethod def construct_key(skeys: KEYS, dkeys: KEYS) -> str: @@ -81,24 +84,65 @@ def get_version(key: str) -> str: """ return key.split("::")[-1] - def __get_model_key(self, latest: bool, version: str, key: str) -> str: - if latest: - production_key = self.__construct_production_key(key) - if not self.client.exists(production_key): - raise ModelKeyNotFound( - f"Production key: {production_key}, Not Found !!!\n Exiting....." - ) - model_key = self.client.get(production_key) - if not self.client.exists(model_key): - raise ModelKeyNotFound( - "Production key = {} is pointing to the key: {} that " - "is missing the redis registry".format(production_key, model_key) - ) - else: - model_key = self.__construct_version_key(key, version) - if not self.client.exists(model_key): - raise ModelKeyNotFound("Could not find model key with key: %s" % model_key) - return model_key + def _load_from_cache(self, key: str) -> Optional[ArtifactData]: + if not self.cache_registry: + return None + return self.cache_registry.load(key) + + def _save_in_cache(self, key: str, artifact_data: ArtifactData) -> None: + if self.cache_registry: + self.cache_registry.save(key, artifact_data) + + def _clear_cache(self, key: Optional[str] = None) -> Optional[ArtifactData]: + if self.cache_registry: + if key: + return self.cache_registry.delete(key) + return self.cache_registry.clear() + return None + + def __get_artifact_data( + self, + model_key: str, + ) -> ArtifactData: + ( + serialized_artifact, + artifact_version, + artifact_timestamp, + serialized_metadata, + ) = self.client.hmget(name=model_key, keys=["artifact", "version", "timestamp", "metadata"]) + deserialized_artifact = loads(serialized_artifact) + deserialized_metadata = None + if serialized_metadata: + deserialized_metadata = loads(serialized_metadata) + return ArtifactData( + artifact=deserialized_artifact, + metadata=deserialized_metadata, + extras={ + "timestamp": float(artifact_timestamp.decode()), + "version": artifact_version.decode(), + }, + ) + + def __load_latest_artifact(self, key: str) -> ArtifactData: + cached_artifact = self._load_from_cache(key) + if cached_artifact: + return cached_artifact + production_key = self.__construct_production_key(key) + if not self.client.exists(production_key): + raise ModelKeyNotFound( + f"Production key: {production_key}, Not Found !!!\n Exiting....." + ) + model_key = self.client.get(production_key) + _LOGGER.info("Production key, %s, is pointing to the key : %s", production_key, model_key) + return self.__load_version_artifact(version=self.get_version(model_key.decode()), key=key) + + def __load_version_artifact(self, version: str, key: str) -> ArtifactData: + model_key = self.__construct_version_key(key, version) + if not self.client.exists(model_key): + raise ModelKeyNotFound("Could not find model key with key: %s" % model_key) + return self.__get_artifact_data( + model_key=model_key, + ) def __save_artifact( self, pipe, artifact: artifact_t, metadata: META_T, key: KEYS, version: str @@ -118,7 +162,7 @@ def __save_artifact( mapping={ "artifact": serialized_artifact, "version": str(version), - "timestamp": int(time.time()), + "timestamp": time.time(), "metadata": serialized_metadata, }, ) @@ -147,30 +191,15 @@ def load( raise ValueError("Either One of 'latest' or 'version' needed in load method call") key = self.construct_key(skeys, dkeys) try: - model_key = self.__get_model_key(latest, version, key) - ( - serialized_artifact, - artifact_version, - artifact_timestamp, - serialized_metadata, - ) = self.client.hmget( - name=model_key, keys=["artifact", "version", "timestamp", "metadata"] - ) - deserialized_artifact = loads(serialized_artifact) - deserialized_metadata = None - if serialized_metadata: - deserialized_metadata = loads(serialized_metadata) + if latest: + artifact_data = self.__load_latest_artifact(key) + self._save_in_cache(key, artifact_data) + else: + artifact_data = self.__load_version_artifact(version, key) except RedisError as err: raise RedisRegistryError(f"{err.__class__.__name__} raised") from err else: - return ArtifactData( - artifact=deserialized_artifact, - metadata=deserialized_metadata, - extras={ - "timestamp": artifact_timestamp.decode(), - "version": artifact_version.decode(), - }, - ) + return artifact_data def save( self, @@ -229,12 +258,14 @@ def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None: ) except RedisError as err: raise RedisRegistryError(f"{err.__class__.__name__} raised") from err + else: + self._clear_cache(del_key) @staticmethod def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool: """ Returns whether the given artifact is stale or not, i.e. if - more time has elasped since it was last retrained. + more time has elapsed since it was last retrained. Args: artifact_data: ArtifactData object to look into freq_hr: Frequency of retraining in hours diff --git a/tests/registry/test_cache.py b/tests/registry/test_cache.py index dfb23050..f7edc9df 100644 --- a/tests/registry/test_cache.py +++ b/tests/registry/test_cache.py @@ -14,6 +14,8 @@ def test_cache(self): cache_reg.load("m1") with self.assertRaises(NotImplementedError): cache_reg.delete("m1") + with self.assertRaises(NotImplementedError): + cache_reg.clear() class TestLocalLRUCache(unittest.TestCase): @@ -63,6 +65,12 @@ def test_delete(self): cache_registry.delete("m1") self.assertIsNone(cache_registry.load("m1")) + def test_clear(self): + cache_registry = LocalLRUCache(cachesize=2, ttl=1) + cache_registry.save("m1", ArtifactData(VanillaAE(10, 1), metadata={}, extras={})) + cache_registry.clear() + self.assertIsNone(cache_registry.load("m1")) + if __name__ == "__main__": unittest.main() diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py index be286694..388bc7eb 100644 --- a/tests/registry/test_mlflow_registry.py +++ b/tests/registry/test_mlflow_registry.py @@ -65,7 +65,6 @@ def test_save_model(self): skeys = self.skeys dkeys = self.dkeys status = ml.save(skeys=skeys, dkeys=dkeys, artifact=self.model, run_id="1234") - print(status) mock_status = "READY" self.assertEqual(mock_status, status.status) @@ -209,7 +208,7 @@ def test_load_model_when_no_model_02(self): @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) @patch( "mlflow.tracking.MlflowClient.transition_model_version_stage", - Mock(side_effect=RestException({"error_code": ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)})), + Mock(side_effect=RestException({"error_code": ErrorCode.Name(RESOURCE_LIMIT_EXCEEDED)})), ) def test_transition_stage_fail(self): fake_skeys = ["Fakemodel_"] @@ -250,7 +249,6 @@ def test_delete_model_when_no_model(self): ml = MLflowRegistry(TRACKING_URI) with self.assertLogs(level="ERROR") as log: ml.delete(skeys=fake_skeys, dkeys=fake_dkeys, version="1") - print(log.output) self.assertTrue(log.output) @patch("mlflow.pytorch.log_model", Mock(side_effect=RuntimeError)) @@ -268,7 +266,7 @@ def test_save_failed(self): @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) @patch( - "mlflow.pytorch.load_model", + "mlflow.tracking.MlflowClient.get_latest_versions", Mock(side_effect=RestException({"error_code": ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)})), ) @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict())) @@ -276,13 +274,14 @@ def test_load_no_model_found(self): ml = MLflowRegistry(TRACKING_URI, artifact_type="pytorch") skeys = self.skeys dkeys = self.dkeys - self.assertIsNone(ml.load(skeys=skeys, dkeys=dkeys)) + data = ml.load(skeys=skeys, dkeys=dkeys) + self.assertIsNone(data) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) @patch( - "mlflow.pytorch.load_model", - Mock(side_effect=RestException({"error_code": ErrorCode.Name(RESOURCE_LIMIT_EXCEEDED)})), + "mlflow.tracking.MlflowClient.get_latest_versions", + Mock(side_effect=RestException({"error_code": ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)})), ) @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict())) def test_load_other_mlflow_err(self): diff --git a/tests/registry/test_redis_registry.py b/tests/registry/test_redis_registry.py index 61904bed..ad84cf09 100644 --- a/tests/registry/test_redis_registry.py +++ b/tests/registry/test_redis_registry.py @@ -8,7 +8,7 @@ from sklearn.preprocessing import StandardScaler from numalogic.models.autoencoder.variants import VanillaAE -from numalogic.registry import RedisRegistry +from numalogic.registry import RedisRegistry, LocalLRUCache, ArtifactData from numalogic.tools.exceptions import ModelKeyNotFound, RedisRegistryError @@ -24,24 +24,73 @@ def setUpClass(cls) -> None: cls.redis_client = fakeredis.FakeStrictRedis(server=server, decode_responses=False) def setUp(self): - self.registry = RedisRegistry(client=self.redis_client) + self.cache = LocalLRUCache(cachesize=4, ttl=300) + self.registry = RedisRegistry( + client=self.redis_client, + cache_registry=self.cache, + ) + self.registry_no_cache = RedisRegistry(client=self.redis_client) def tearDown(self) -> None: self.registry.client.flushall() + self.registry_no_cache.client.flushall() + self.cache.clear() + + def test_no_cache(self): + self.assertIsNone( + self.registry_no_cache._save_in_cache( + "key", ArtifactData(artifact=self.pytorch_model, extras={}, metadata={}) + ) + ) + self.assertIsNone(self.registry_no_cache._load_from_cache("key")) + self.assertIsNone(self.registry_no_cache._clear_cache("key")) def test_construct_key(self): key = RedisRegistry.construct_key(["model_", "nnet"], ["error1"]) self.assertEqual("model_:nnet::error1", key) - def test_save_model_without_metadata(self): + def test_save_model_without_metadata_cache_hit(self): save_version = self.registry.save( skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model ) - resave_version = self.registry.save( + data = self.registry.load(skeys=self.skeys, dkeys=self.dkeys) + self.assertEqual(data.extras["version"], save_version) + resave_version1 = self.registry.save( skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model ) + resave_data = self.registry.load(skeys=self.skeys, dkeys=self.dkeys) self.assertEqual(save_version, "0") - self.assertEqual(resave_version, "1") + self.assertEqual(resave_version1, "1") + self.assertEqual(resave_data.extras["version"], "0") + + def test_save_load_without_cache(self): + save_version = self.registry_no_cache.save( + skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model + ) + data = self.registry_no_cache.load(skeys=self.skeys, dkeys=self.dkeys) + self.assertEqual(data.extras["version"], save_version) + resave_version1 = self.registry_no_cache.save( + skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model + ) + resave_data = self.registry_no_cache.load(skeys=self.skeys, dkeys=self.dkeys) + self.assertEqual(save_version, "0") + self.assertEqual(resave_version1, "1") + self.assertEqual(resave_data.extras["version"], "1") + + def test_save_model_without_metadata_cache_miss(self): + save_version = self.registry.save( + skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model + ) + data = self.registry.load(skeys=self.skeys, dkeys=self.dkeys) + self.assertEqual(data.extras["version"], save_version) + resave_version1 = self.registry.save( + skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model + ) + self.cache.clear() + resave_data = self.registry.load(skeys=self.skeys, dkeys=self.dkeys) + self.assertEqual(save_version, "0") + self.assertEqual(resave_version1, "1") + self.assertEqual(resave_data.extras["version"], "1") def test_load_model_without_metadata(self): version = self.registry.save( @@ -61,6 +110,14 @@ def test_load_model_with_metadata(self): self.assertIsNotNone(data.metadata) self.assertEqual(data.extras["version"], version) + def test_delete_model(self): + version = self.registry.save( + skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model + ) + self.registry.delete(skeys=self.skeys, dkeys=self.dkeys, version=version) + with self.assertRaises(ModelKeyNotFound): + self.registry.load(skeys=self.skeys, dkeys=self.dkeys) + def test_load_model_with_version(self): version = self.registry.save( skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model