Skip to content

Commit

Permalink
chore!: Auto detect instance type while mlflow model save (#190)
Browse files Browse the repository at this point in the history
Helps for a better config-driven registry instance creation.

- Avoid specifying mlflow instance type in the constructor
- Detect the correct handler during save
- Specify artifact type during load
- Update docs

Signed-off-by: Avik Basu <ab93@users.noreply.github.com>
  • Loading branch information
ab93 committed May 11, 2023
1 parent 7dd3ada commit d9d62ff
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 69 deletions.
33 changes: 21 additions & 12 deletions docs/ml-flow.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,21 @@ Numalogic provides `MLflowRegistry`, to save and load models to/from MLflow.

Here, `tracking_uri` is the uri where mlflow server is running. The `static_keys` and `dynamic_keys` are used to form a unique key for the model.

The `primary_artifact` would be the main model, and `secondary_artifacts` can be used to save any pre-processing models like scalers.

The `artifact` would be the model or transformer object that needs to be saved.
A dictionary of metadata can also be saved along with the artifact.
```python
from numalogic.registry import MLflowRegistry
from numalogic.models.autoencoder.variants import VanillaAE

model = VanillaAE(seq_len=10)

# static and dynamic keys are used to look up a model
static_keys = ["synthetic", "3ts"]
dynamic_keys = ["minmaxscaler", "sparseconv1d"]
static_keys = ["model", "autoencoder"]
dynamic_keys = ["vanilla", "seq10"]

registry = MLflowRegistry(tracking_uri="http://0.0.0.0:5000", artifact_type="pytorch")
registry = MLflowRegistry(tracking_uri="http://0.0.0.0:5000")
registry.save(
skeys=static_keys,
dkeys=dynamic_keys,
primary_artifact=model,
secondary_artifacts={"preproc": scaler},
skeys=static_keys, dkeys=dynamic_keys, artifact=model, seq_len=10, lr=0.001
)
```

Expand All @@ -46,10 +46,19 @@ registry.save(
Once, the models are save to MLflow, the `load` function of `MLflowRegistry` can be used to load the model.

```python
from numalogic.registry import MLflowRegistry

static_keys = ["model", "autoencoder"]
dynamic_keys = ["vanilla", "seq10"]

registry = MLflowRegistry(tracking_uri="http://0.0.0.0:8080")
artifact_dict = registry.load(skeys=static_keys, dkeys=dynamic_keys)
scaler = artifact_dict["secondary_artifacts"]["preproc"]
model = artifact_dict["primary_artifact"]
artifact_data = registry.load(
skeys=static_keys, dkeys=dynamic_keys, artifact_type="pytorch"
)

# get the model and metadata
model = artifact_data.artifact
model_metadata = artifact_data.metadata
```

For more details, please refer to [MLflow Model Registry](https://www.mlflow.org/docs/latest/model-registry.html#)
78 changes: 52 additions & 26 deletions numalogic/registry/mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from mlflow.exceptions import RestException
from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST
from mlflow.tracking import MlflowClient
from sklearn.base import BaseEstimator
from torch import nn

from numalogic.registry import ArtifactManager, ArtifactData
from numalogic.registry.artifact import ArtifactCache
Expand Down Expand Up @@ -51,9 +53,6 @@ class MLflowRegistry(ArtifactManager):
Args:
tracking_uri: the tracking server uri to use for mlflow
artifact_type: the type of primary artifact to use
supported values include:
{"pytorch", "sklearn", "tensorflow", "pyfunc"}
models_to_retain: number of models to retain in the DB (default = 5)
model_stage: Staging environment from where to load the latest model from (mlflow )
supported values include:
Expand All @@ -67,18 +66,17 @@ class MLflowRegistry(ArtifactManager):
>>>
>>> data = [[0, 0], [0, 0], [1, 1], [1, 1]]
>>> scaler = StandardScaler.fit(data)
>>> registry = MLflowRegistry(tracking_uri="http://0.0.0.0:8080", artifact_type="pytorch")
>>> registry = MLflowRegistry(tracking_uri="http://0.0.0.0:8080")
>>> registry.save(skeys=["model"], dkeys=["AE"], artifact=VanillaAE(10))
>>> artifact_data = registry.load(skeys=["model"], dkeys=["AE"])
>>> artifact_data = registry.load(skeys=["model"], dkeys=["AE"], artifact_type="pytorch")
"""

__slots__ = ("client", "handler", "models_to_retain", "model_stage", "cache_registry")
__slots__ = ("client", "models_to_retain", "model_stage", "cache_registry")
_TRACKING_URI = None

def __new__(
cls,
tracking_uri: Optional[str],
artifact_type: str = "pytorch",
models_to_retain: int = 5,
model_stage: ModelStage = ModelStage.PRODUCTION,
cache_registry: ArtifactCache = None,
Expand All @@ -93,30 +91,34 @@ def __new__(
def __init__(
self,
tracking_uri: str,
artifact_type: str = "pytorch",
models_to_retain: int = 5,
model_stage: str = ModelStage.PRODUCTION,
cache_registry: ArtifactCache = None,
):
super().__init__(tracking_uri)
mlflow.set_tracking_uri(tracking_uri)
self.client = MlflowClient()
self.handler = self.mlflow_handler(artifact_type)
self.models_to_retain = models_to_retain
self.model_stage = model_stage
self.cache_registry = cache_registry

@staticmethod
def mlflow_handler(artifact_type: str):
def handler_from_obj(artifact: artifact_t):
if isinstance(artifact, nn.Module):
return mlflow.pytorch
if isinstance(artifact, BaseEstimator):
return mlflow.sklearn
return mlflow.pyfunc

@staticmethod
def handler_from_type(artifact_type: str):
"""
Helper method to return the right handler given the artifact type.
"""
if artifact_type == "pytorch":
return mlflow.pytorch
if artifact_type == "sklearn":
return mlflow.sklearn
if artifact_type == "tensorflow":
return mlflow.tensorflow
if artifact_type == "pyfunc":
return mlflow.pyfunc
raise NotImplementedError("Artifact Type not Implemented")
Expand All @@ -136,8 +138,25 @@ def _clear_cache(self, key: str) -> Optional[ArtifactData]:
return None

def load(
self, skeys: KEYS, dkeys: KEYS, latest: bool = True, version: str = None
self,
skeys: KEYS,
dkeys: KEYS,
latest: bool = True,
version: str = None,
artifact_type: str = "pytorch",
) -> Optional[ArtifactData]:
"""
Load the artifact from the registry. The artifact is loaded from the cache if available.
Args:
skeys: Static keys
dkeys: Dynamic keys
latest: Load the latest version of the model (default = True)
version: Version of the model to load (default = None)
artifact_type: Type of the artifact to load (default = "pytorch")
Returns:
The loaded ArtifactData object if available otherwise None
"""
model_key = self.construct_key(skeys, dkeys)

if (latest and version) or (not latest and not version):
Expand All @@ -154,30 +173,35 @@ def load(
version_info = version_info[-1]
else:
version_info = self.client.get_model_version(model_key, version)
model, metadata = self.__load_artifacts(skeys, dkeys, version_info)
model, metadata = self.__load_artifacts(skeys, dkeys, version_info, artifact_type)
except RestException as mlflow_err:
if ErrorCode.Value(mlflow_err.error_code) == RESOURCE_DOES_NOT_EXIST:
_LOGGER.info("Model not found with key: %s", model_key)
else:
_LOGGER.exception(
"Mlflow error when loading a model with key: %s: %r", model_key, mlflow_err
)
return None
return self.__log_mlflow_err(mlflow_err, model_key)
except ModelVersionError as model_missing_err:
_LOGGER.error(
"No Model found found in %s ERROR: %r", self.model_stage, model_missing_err
)
return None
except Exception as ex:
_LOGGER.exception("Unexpected error: %s", ex)
_LOGGER.exception("Unexpected error: %r", ex)
return None
else:
artifact_data = ArtifactData(
artifact=model, metadata=metadata, extras=dict(version_info)
)
self._save_in_cache(model_key, artifact_data)
# save in cache if loading the latest version
if latest:
self._save_in_cache(model_key, artifact_data)
return artifact_data

@staticmethod
def __log_mlflow_err(mlflow_err: RestException, model_key: str) -> None:
if ErrorCode.Value(mlflow_err.error_code) == RESOURCE_DOES_NOT_EXIST:
_LOGGER.info("Model not found with key: %s", model_key)
else:
_LOGGER.exception(
"Mlflow error when loading a model with key: %s: %r", model_key, mlflow_err
)

def save(
self,
skeys: KEYS,
Expand All @@ -199,9 +223,10 @@ def save(
mlflow ModelVersion instance
"""
model_key = self.construct_key(skeys, dkeys)
handler = self.handler_from_obj(artifact)
try:
mlflow.start_run(run_id=run_id)
self.handler.log_model(artifact, "model", registered_model_name=model_key)
handler.log_model(artifact, "model", registered_model_name=model_key)
if metadata:
mlflow.log_params(metadata)
model_version = self.transition_stage(skeys=skeys, dkeys=dkeys)
Expand Down Expand Up @@ -298,10 +323,11 @@ def __delete_stale_models(self, skeys: KEYS, dkeys: KEYS):
_LOGGER.debug("Deleted stale model version : %s", stale_model.version)

def __load_artifacts(
self, skeys: KEYS, dkeys: KEYS, version_info: ModelVersion
self, skeys: KEYS, dkeys: KEYS, version_info: ModelVersion, artifact_type: str
) -> tuple[artifact_t, dict[str, Any]]:
model_key = self.construct_key(skeys, dkeys)
model = self.handler.load_model(model_uri=f"models:/{model_key}/{version_info.version}")
handler = self.handler_from_type(artifact_type)
model = handler.load_model(model_uri=f"models:/{model_key}/{version_info.version}")
_LOGGER.info("Successfully loaded model %s from Mlflow", model_key)

run_info = mlflow.get_run(version_info.run_id)
Expand Down
6 changes: 3 additions & 3 deletions numalogic/tools/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from sklearn.base import BaseEstimator
from torch import nn

artifact_t = TypeVar("artifact_t", bound=Union[nn.Module, BaseEstimator])
META_T = TypeVar("META_T", bound=dict[str, Union[str, list, dict]])
META_VT = TypeVar("META_VT", bound=Union[str, list, dict])
artifact_t = TypeVar("artifact_t", bound=Union[nn.Module, BaseEstimator], covariant=True)
META_T = TypeVar("META_T", bound=dict[str, Union[str, float, int, list, dict]])
META_VT = TypeVar("META_VT", str, int, float, list, dict)
EXTRA_T = TypeVar("EXTRA_T", bound=dict[str, Union[str, list, dict]])
redis_client_t = TypeVar("redis_client_t", bound=AbstractRedis, covariant=True)
KEYS = TypeVar("KEYS", bound=Sequence[str], covariant=True)
Expand Down
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.4.dev5"
version = "0.4.a0"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down
Loading

0 comments on commit d9d62ff

Please sign in to comment.