-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ Create an MlflowMetricDataSet to simplify metric logging (#73)
- Loading branch information
1 parent
d626c8e
commit 1eecda2
Showing
10 changed files
with
517 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .mlflow_metric_dataset import MlflowMetricDataSet | ||
from .mlflow_metrics_dataset import MlflowMetricsDataSet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from typing import Any, Dict, Union | ||
|
||
import mlflow | ||
from kedro.io import AbstractDataSet | ||
from mlflow.tracking import MlflowClient | ||
|
||
|
||
class MlflowAbstractMetricDataSet(AbstractDataSet):
    """Shared base for datasets that read or write metrics in an mlflow run.

    The class stores the metric ``key``, the target ``run_id`` (falling back
    on the active mlflow run when none was given), the load / save arguments
    and a private switch that lets logging be turned off. It implements
    ``_exists`` and ``_describe`` but defines neither ``_load`` nor ``_save``.
    """

    def __init__(
        self,
        run_id: Union[str, None] = None,
        key: Union[str, None] = None,
        load_args: Union[Dict[str, Any], None] = None,
        save_args: Union[Dict[str, Any], None] = None,
    ):
        """Initialise the metric dataset.

        Args:
            run_id: The ID of the mlflow run where the metric should be
                logged. When ``None``, the active run (if any) is used.
            key: The name of the metric.
            load_args: Extra arguments stored as ``self._load_args``
                (an empty dict when omitted).
            save_args: Extra arguments stored as ``self._save_args``
                (an empty dict when omitted).
        """
        self.key = key
        self.run_id = run_id
        self._load_args = load_args or {}
        self._save_args = save_args or {}
        self._logging_activated = True  # by default, logging is activated!

    @property
    def run_id(self) -> Union[str, None]:
        """Return the explicit run id, else the active run's id, else ``None``."""
        run = mlflow.active_run()
        if self._run_id is None and run is not None:
            # if no run_id is specified, we try to retrieve the current run.
            # This is useful because during a kedro run, we want to be able
            # to retrieve the metric from the active run to be able to reload
            # a metric without specifying the (unknown) run id
            return run.info.run_id

        # else we return the _run_id, which may be None.
        # In this case, saving will work (a new run will be created)
        # but loading will fail, according to mlflow's behaviour
        return self._run_id

    @run_id.setter
    def run_id(self, run_id: Union[str, None]) -> None:
        """Store the explicitly requested run id (``None`` means "use active run")."""
        self._run_id = run_id

    # We want to be able to turn logging off for an entire pipeline run.
    # To avoid that a single call to a dataset in the catalog creates a new
    # run automatically, we want to be able to turn everything off.
    @property
    def _logging_activated(self) -> bool:
        """Whether this dataset is allowed to log to mlflow."""
        return self.__logging_activated

    @_logging_activated.setter
    def _logging_activated(self, flag: bool) -> None:
        """Set the logging switch.

        Raises:
            ValueError: If ``flag`` is not a ``bool``.
        """
        if not isinstance(flag, bool):
            raise ValueError(f"_logging_activated must be a boolean, got {type(flag)}")
        self.__logging_activated = flag

    def _validate_run_id(self) -> None:
        """Ensure a run id can be resolved.

        Raises:
            ValueError: If no ``run_id`` was specified and no mlflow run is
                active.
        """
        if self.run_id is None:
            raise ValueError(
                "You must either specify a run_id or have a mlflow active run opened. Use mlflow.start_run() if necessary."
            )

    def _exists(self) -> bool:
        """Check whether the metric exists in the mlflow run.

        Returns:
            bool: ``True`` if ``self.key`` is one of the metrics of the
            resolved run, ``False`` if it is not or if no run can be resolved.
        """
        mlflow_client = MlflowClient()
        run_id = self.run_id  # will get the active run if nothing is specified
        run = mlflow_client.get_run(run_id) if run_id else mlflow.active_run()

        # ``run.data.metrics`` is tested directly: membership on the mapping
        # is the same check as membership on its ``.keys()`` view.
        return self.key in run.data.metrics if run else False

    def _describe(self) -> Dict[str, Any]:
        """Describe the MLflow metric dataset.

        Returns:
            Dict[str, Any]: Dictionary with the metric ``key`` and the
            resolved ``run_id``.
        """
        return {
            "key": self.key,
            "run_id": self.run_id,
        }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from copy import deepcopy | ||
from typing import Any, Dict | ||
|
||
from mlflow.tracking import MlflowClient | ||
|
||
from kedro_mlflow.io.metrics.mlflow_abstract_metric_dataset import ( | ||
MlflowAbstractMetricDataSet, | ||
) | ||
|
||
|
||
class MlflowMetricDataSet(MlflowAbstractMetricDataSet):
    """Dataset that saves and loads one single mlflow metric value.

    On top of the arguments forwarded to ``MlflowClient.log_metric``,
    ``save_args`` accepts a ``"mode"`` entry (``"overwrite"`` or ``"append"``)
    that drives how the step is chosen when no explicit ``"step"`` is given.
    """

    SUPPORTED_SAVE_MODES = {"overwrite", "append"}
    DEFAULT_SAVE_MODE = "overwrite"

    def __init__(
        self,
        run_id: str = None,
        key: str = None,
        load_args: Dict[str, Any] = None,
        save_args: Dict[str, Any] = None,
    ):
        """Initialise MlflowMetricDataSet.

        Args:
            run_id (str): The ID of the mlflow run where the metric should be
                logged. When omitted, the active run is used if there is one.
            key (str): The name of the metric.
            load_args (Dict[str, Any]): Load options; ``"step"`` selects the
                step whose value is returned by ``_load``.
            save_args (Dict[str, Any]): Save options; ``"mode"`` and
                ``"step"`` are consumed by this class, the remaining entries
                are passed to ``MlflowClient.log_metric``.
        """
        super().__init__(run_id, key, load_args, save_args)

        # The parent keeps a reference to the caller's ``save_args`` dict.
        # Work on a copy so that popping "mode" below does not alter the
        # caller's dict (which may be reused to build another dataset).
        self._save_args = dict(self._save_args)

        # We add an extra argument mode="overwrite" / "append" to enable
        # updating an existing metric. This is not an official mlflow argument
        # for log_metric, so we separate it from the others.
        # "overwrite" corresponds to the default mlflow behaviour.
        self.mode = self._save_args.pop("mode", self.DEFAULT_SAVE_MODE)

    def _load(self):
        """Return the stored value of the metric.

        Returns:
            The last recorded value of the metric or, when
            ``load_args["step"]`` is set, the last recorded value at that step.

        Raises:
            ValueError: If no run id can be resolved, if the metric has no
                recorded value in the run, or if it has no value at the
                requested step.
        """
        self._validate_run_id()
        run_id = self.run_id  # resolved once; gets active run if none was given
        mlflow_client = MlflowClient()
        metric_history = mlflow_client.get_metric_history(run_id=run_id, key=self.key)

        if not metric_history:
            # fail with an explicit message instead of a bare IndexError
            raise ValueError(
                f"No value was found for metric '{self.key}' in run '{run_id}'."
            )

        # the metric history is a list of mlflow.entities.metric.Metric;
        # this dataset only deals with one single metric, so we return the
        # value of the last one stored
        step = self._load_args.get("step")
        if step is None:
            return metric_history[-1].value

        # we take the last historical value with the given step
        # (it is possible to have several values with the same step)
        values_at_step = [
            metric.value for metric in metric_history if metric.step == step
        ]
        if not values_at_step:
            # fail with an explicit message instead of a bare StopIteration
            raise ValueError(
                f"No value was found at step {step} for metric '{self.key}' in run '{run_id}'."
            )
        return values_at_step[-1]

    def _save(self, data: float):
        """Log ``data`` as the value of the metric, unless logging is turned off.

        When ``save_args`` has no ``"step"``, the step is derived from the
        existing history: ``"overwrite"`` reuses the highest existing step
        (0 if there is none), ``"append"`` uses the highest existing step + 1
        (0 if there is none).

        Args:
            data (float): The metric value to log.

        Raises:
            ValueError: If no run id can be resolved, or if no step is given
                and ``mode`` is not one of ``SUPPORTED_SAVE_MODES``.
        """
        if not self._logging_activated:
            return

        self._validate_run_id()
        # we access it once instead of calling self.run_id everywhere
        # to avoid looking for an active run each time
        run_id = self.run_id

        mlflow_client = MlflowClient()

        # get the metric history if it has been saved previously, to ensure
        # we retrieve the right data
        # reminder: this works even if no run_id was originally specified
        # but a run is active
        metric_history = (
            mlflow_client.get_metric_history(run_id=run_id, key=self.key)
            if self._exists()
            else []
        )

        save_args = deepcopy(self._save_args)
        step = save_args.pop("step", None)
        if step is None:
            # NOTE: the mode is only checked on this branch; an explicit
            # "step" bypasses it entirely.
            known_steps = [metric.step for metric in metric_history]
            if self.mode == "overwrite":
                step = max(known_steps, default=0)
            elif self.mode == "append":
                # max([]) defaults to -1 so that the first "step" equals 0
                step = max(known_steps, default=-1) + 1
            else:
                raise ValueError(
                    f"save_args['mode'] must be one of {self.SUPPORTED_SAVE_MODES}, got '{self.mode}' instead."
                )

        mlflow_client.log_metric(
            run_id=run_id,
            key=self.key,
            value=data,
            step=step,
            **save_args,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.