From 10879596035dcaad3695709fe79a0eec647996fc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Yolan=20Honor=C3=A9-Roug=C3=A9?=
Date: Wed, 18 Aug 2021 23:47:13 +0200
Subject: [PATCH] :sparkles: Create an MlflowMetricDataSet to simplify metric logging (#73)

---
 CHANGELOG.md                                  |   4 +
 .../05_version_metrics.md                     |  62 +++++-
 kedro_mlflow/framework/hooks/pipeline_hook.py |  17 +-
 kedro_mlflow/io/metrics/__init__.py           |   1 +
 .../metrics/mlflow_abstract_metric_dataset.py |  91 ++++++++
 .../io/metrics/mlflow_metric_dataset.py       |  98 +++++++++
 tests/framework/hooks/test_all_hooks.py       |  19 +-
 tests/framework/hooks/test_pipeline_hook.py   |  36 +++-
 .../io/metrics/test_mlflow_metric_dataset.py  | 202 ++++++++++++++++++
 .../io/metrics/test_mlflow_metrics_dataset.py |   2 +-
 10 files changed, 517 insertions(+), 15 deletions(-)
 create mode 100644 kedro_mlflow/io/metrics/mlflow_abstract_metric_dataset.py
 create mode 100644 kedro_mlflow/io/metrics/mlflow_metric_dataset.py
 create mode 100644 tests/io/metrics/test_mlflow_metric_dataset.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ae04fffa..8ffa2779 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,10 @@
 
 ## [Unreleased]
 
+### Added
+
+- :sparkles: Create an ``MlflowMetricDataSet`` to simplify the existing metric API. It enables logging a single float as a metric and can automatically increase the metric's "step" when it is updated over time ([#73](https://github.com/Galileo-Galilei/kedro-mlflow/issues/73))
+
 ### Fixed
 
 - :bug: Dictionnary parameters with integer keys are now properly logged in mlflow when ``flatten_dict_params`` is set to ``True`` in the ``mlflow.yml`` instead of raising a ``TypeError`` ([#224](https://github.com/Galileo-Galilei/kedro-mlflow/discussions/224))
diff --git a/docs/source/04_experimentation_tracking/05_version_metrics.md b/docs/source/04_experimentation_tracking/05_version_metrics.md
index ba4f4803..f9422ade 100644
--- a/docs/source/04_experimentation_tracking/05_version_metrics.md
+++ b/docs/source/04_experimentation_tracking/05_version_metrics.md
@@ -6,9 +6,67 @@ MLflow defines a metric as "a (key, value) pair, where the value is numeric". Ea
 
 ## How to version metrics in a kedro project?
 
-`kedro-mlflow` introduces a new ``AbstractDataSet`` called ``MlflowMetricsDataSet``. It is a wrapper around a dictionary with metrics which is returned by node and log metrics in MLflow.
+`kedro-mlflow` introduces two new ``AbstractDataSet`` implementations:
+- ``MlflowMetricDataSet``, which logs a single float as a metric
+- ``MlflowMetricsDataSet``, which is a wrapper around a dictionary of metrics returned by a node and logs them in MLflow.
 
-Since it is an ``AbstractDataSet``, it can be used with the YAML API. You can define it as:
+### Saving a single float as a metric with ``MlflowMetricDataSet``
+
+The ``MlflowMetricDataSet`` is an ``AbstractDataSet`` which enables saving or loading a ``float`` as an mlflow metric. You must specify the ``key`` (i.e. the name to display in mlflow) when creating the dataset. Some examples follow:
+
+- The most basic usage is to create the dataset and save a value:
+
+```python
+import mlflow
+
+from kedro_mlflow.io.metrics import MlflowMetricDataSet
+
+metric_ds = MlflowMetricDataSet(key="my_metric")
+with mlflow.start_run():
+    metric_ds.save(0.3)  # creates a "my_metric=0.3" value in the "metric" field of the mlflow UI
+```
+
+**Beware: unlike mlflow's default behaviour, if there is no active run, no run is created.**
+
+- You can also specify a ``run_id`` instead of logging to the active run:
+
+```python
+import mlflow
+
+from kedro_mlflow.io.metrics import MlflowMetricDataSet
+
+metric_ds = MlflowMetricDataSet(key="my_metric", run_id="123456789")
+with mlflow.start_run():
+    metric_ds.save(0.3)  # creates a "my_metric=0.3" value in the "metric" field of the run 123456789
+```
+
+It is also possible to pass ``load_args`` and ``save_args`` to control which step should be loaded or logged (in case you have logged several steps for the same metric). ``save_args`` accepts a ``mode`` key which can be set to ``overwrite`` (the mlflow default) or ``append``. In append mode, if no step is specified, saving the metric will bump the last existing step by one to create a linear history. **This is very useful if you have a monitoring pipeline which calculates a metric frequently to check the performance of a deployed model.**
+
+```python
+import mlflow
+
+from kedro_mlflow.io.metrics import MlflowMetricDataSet
+
+metric_ds = MlflowMetricDataSet(key="my_metric", load_args={"step": 1}, save_args={"mode": "append"})
+
+with mlflow.start_run():
+    metric_ds.save(0)  # step 0 stored for "my_metric"
+    metric_ds.save(0.1)  # step 1 stored for "my_metric"
+    metric_ds.save(0.2)  # step 2 stored for "my_metric"
+
+    my_metric = metric_ds.load()  # value=0.1 (step number 1)
+```
+
+Since it is an ``AbstractDataSet``, it can be used with the YAML API in your ``catalog.yml``, e.g.:
+
+```yaml
+my_model_metric:
+    type: kedro_mlflow.io.metrics.MlflowMetricDataSet
+    run_id: 123456 # OPTIONAL, you should likely leave it empty to log to the current run
+    key: my_awesome_name # OPTIONAL: if not provided, the dataset name will be used (here "my_model_metric")
+    load_args:
+        step: ... # OPTIONAL: likely not provided, unless you have a very good reason to do so
+    save_args:
+        step: ... # OPTIONAL: likely not provided, unless you have a very good reason to do so
+        mode: append # OPTIONAL: likely better than the default "overwrite". Will be ignored if "step" is provided.
+```
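+
+As a minimal sketch of how this fits into a pipeline (the node and dataset names below are only illustrative), a node simply returns the ``float`` and kedro saves it through the ``my_model_metric`` entry declared above; the ``key`` then defaults to the dataset name:
+
+```python
+from kedro.pipeline import Pipeline, node
+
+
+def evaluate_model(model, test_data) -> float:
+    # hypothetical evaluation step returning a single float
+    return 0.93
+
+
+def create_pipeline(**kwargs) -> Pipeline:
+    return Pipeline(
+        [
+            node(
+                func=evaluate_model,
+                inputs=["model", "test_data"],
+                outputs="my_model_metric",  # matches the catalog entry above
+            )
+        ]
+    )
+```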
+
+### Saving several metrics with their entire history with ``MlflowMetricsDataSet``
+
+Since it is an ``AbstractDataSet``, it can be used with the YAML API. You can define it in your ``catalog.yml`` as:
 
 ```yaml
 my_model_metrics:
diff --git a/kedro_mlflow/framework/hooks/pipeline_hook.py b/kedro_mlflow/framework/hooks/pipeline_hook.py
index 5482a42b..f3b0e9fd 100644
--- a/kedro_mlflow/framework/hooks/pipeline_hook.py
+++ b/kedro_mlflow/framework/hooks/pipeline_hook.py
@@ -16,7 +16,7 @@
 from kedro_mlflow.framework.context import get_mlflow_config
 from kedro_mlflow.framework.hooks.utils import _assert_mlflow_enabled
 from kedro_mlflow.io.catalog.switch_catalog_logging import switch_catalog_logging
-from kedro_mlflow.io.metrics import MlflowMetricsDataSet
+from kedro_mlflow.io.metrics import MlflowMetricDataSet, MlflowMetricsDataSet
 from kedro_mlflow.mlflow import KedroPipelineModel
 from kedro_mlflow.pipeline.pipeline_ml import PipelineML
 from kedro_mlflow.utils import _parse_requirements
@@ -47,6 +47,21 @@ def after_catalog_created(
                 else:
                     catalog._data_sets[name] = MlflowMetricsDataSet(prefix=name)
 
+            if isinstance(dataset, MlflowMetricDataSet) and dataset.key is None:
+                if dataset._run_id is not None:
+                    catalog._data_sets[name] = MlflowMetricDataSet(
+                        run_id=dataset._run_id,
+                        key=name,
+                        load_args=dataset._load_args,
+                        save_args=dataset._save_args,
+                    )
+                else:
+                    catalog._data_sets[name] = MlflowMetricDataSet(
+                        key=name,
+                        load_args=dataset._load_args,
+                        save_args=dataset._save_args,
+                    )
+
     @hook_impl
     def before_pipeline_run(
         self, run_params: Dict[str, Any], pipeline: Pipeline, catalog: DataCatalog
diff --git a/kedro_mlflow/io/metrics/__init__.py b/kedro_mlflow/io/metrics/__init__.py
index b95d3ed2..6252d181 100644
--- a/kedro_mlflow/io/metrics/__init__.py
+++ b/kedro_mlflow/io/metrics/__init__.py
@@ -1 +1,2 @@
+from .mlflow_metric_dataset import MlflowMetricDataSet
 from .mlflow_metrics_dataset import MlflowMetricsDataSet
diff --git a/kedro_mlflow/io/metrics/mlflow_abstract_metric_dataset.py b/kedro_mlflow/io/metrics/mlflow_abstract_metric_dataset.py
new file mode 100644
index 00000000..337af996
--- /dev/null
+++ b/kedro_mlflow/io/metrics/mlflow_abstract_metric_dataset.py
@@ -0,0 +1,91 @@
+from typing import Any, Dict, Union
+
+import mlflow
+from kedro.io import AbstractDataSet
+from mlflow.tracking import MlflowClient
+
+
+class MlflowAbstractMetricDataSet(AbstractDataSet):
+    def __init__(
+        self,
+        key: str = None,
+        run_id: str = None,
+        load_args: Dict[str, Any] = None,
+        save_args: Dict[str, Any] = None,
+    ):
+        """Initialise MlflowAbstractMetricDataSet.
+
+        Args:
+            run_id (str): The ID of the mlflow run where the metric should be logged
+        """
+
+        self.key = key
+        self.run_id = run_id
+        self._load_args = load_args or {}
+        self._save_args = save_args or {}
+        self._logging_activated = True  # by default, logging is activated!
+
+    @property
+    def run_id(self) -> Union[str, None]:
+        """Get run id."""
+
+        run = mlflow.active_run()
+        if (self._run_id is None) and (run is not None):
+            # if no run_id is specified, we try to retrieve the current run
+            # this is useful because during a kedro run, we want to be able to retrieve
+            # the metric from the active run to be able to reload a metric
+            # without specifying the (unknown) run id
+            return run.info.run_id
+
+        # else we return the _run_id, which may still be None.
+        # In this case, saving will work (a new run will be created)
+        # but loading will fail,
+        # according to mlflow's behaviour
+        return self._run_id
+
+    @run_id.setter
+    def run_id(self, run_id: str):
+        self._run_id = run_id
+
+    # we want to be able to turn logging off for an entire pipeline run,
+    # so that a single call to a dataset in the catalog
+    # does not automatically create a new mlflow run
+    @property
+    def _logging_activated(self):
+        return self.__logging_activated
+
+    @_logging_activated.setter
+    def _logging_activated(self, flag):
+        if not isinstance(flag, bool):
+            raise ValueError(f"_logging_activated must be a boolean, got {type(flag)}")
+        self.__logging_activated = flag
+
+    def _validate_run_id(self):
+        if self.run_id is None:
+            raise ValueError(
+                "You must either specify a run_id or have a mlflow active run opened. Use mlflow.start_run() if necessary."
+            )
+
+    def _exists(self) -> bool:
+        """Check if the metric exists in the remote mlflow storage.
+
+        Returns:
+            bool: Does the metric name exist in the given run_id?
+        """
+        mlflow_client = MlflowClient()
+        run_id = self.run_id  # will get the active run if nothing is specified
+        run = mlflow_client.get_run(run_id) if run_id else mlflow.active_run()
+
+        flag_exist = self.key in run.data.metrics.keys() if run else False
+        return flag_exist
+
+    def _describe(self) -> Dict[str, Any]:
+        """Describe MLflow metric dataset.
+
+        Returns:
+            Dict[str, Any]: Dictionary with MLflow metric dataset description.
+        """
+        return {
+            "key": self.key,
+            "run_id": self.run_id,
+        }
diff --git a/kedro_mlflow/io/metrics/mlflow_metric_dataset.py b/kedro_mlflow/io/metrics/mlflow_metric_dataset.py
new file mode 100644
index 00000000..b0432e35
--- /dev/null
+++ b/kedro_mlflow/io/metrics/mlflow_metric_dataset.py
@@ -0,0 +1,98 @@
+from copy import deepcopy
+from typing import Any, Dict
+
+from mlflow.tracking import MlflowClient
+
+from kedro_mlflow.io.metrics.mlflow_abstract_metric_dataset import (
+    MlflowAbstractMetricDataSet,
+)
+
+
+class MlflowMetricDataSet(MlflowAbstractMetricDataSet):
+    SUPPORTED_SAVE_MODES = {"overwrite", "append"}
+    DEFAULT_SAVE_MODE = "overwrite"
+
+    def __init__(
+        self,
+        key: str = None,
+        run_id: str = None,
+        load_args: Dict[str, Any] = None,
+        save_args: Dict[str, Any] = None,
+    ):
+        """Initialise MlflowMetricDataSet.
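+
+        The dataset logs a single float under ``key``. The ``save_args`` accept an extra
+        "mode" entry ("overwrite", the mlflow default, or "append") which is popped
+        before the remaining arguments are passed to mlflow.log_metric.
+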
+        Args:
+            run_id (str): The ID of the mlflow run where the metric should be logged
+        """
+
+        super().__init__(key, run_id, load_args, save_args)
+
+        # We add an extra argument mode="overwrite" / "append" to enable updating an existing metric.
+        # This is not an official mlflow argument for log_metric, so we separate it from the others
+        # "overwrite" corresponds to the default mlflow behaviour
+        self.mode = self._save_args.pop("mode", self.DEFAULT_SAVE_MODE)
+
+    def _load(self):
+        self._validate_run_id()
+        mlflow_client = MlflowClient()
+        metric_history = mlflow_client.get_metric_history(
+            run_id=self.run_id, key=self.key
+        )  # gets active run if no run_id was given
+
+        # the metric history is always a list of mlflow.entities.metric.Metric
+        # we want the value of the last one stored because this dataset only deals with a single metric
+        step = self._load_args.get("step")
+
+        if step is None:
+            # we take the last value recorded
+            metric_value = metric_history[-1].value
+        else:
+            # we take the last historical value with the given step
+            # (it is possible to have several values with the same step)
+            metric_value = next(
+                metric.value
+                for metric in reversed(metric_history)
+                if metric.step == step
+            )
+
+        return metric_value
+
+    def _save(self, data: float):
+        if self._logging_activated:
+            self._validate_run_id()
+            run_id = (
+                self.run_id
+            )  # we access it once instead of calling self.run_id everywhere to avoid looking for an active run each time
+
+            mlflow_client = MlflowClient()
+
+            # get the metric history if it has been saved previously
+            # to ensure we retrieve the right data
+            # reminder: this is True even if no run_id was originally specified but a run is active
+            metric_history = (
+                mlflow_client.get_metric_history(run_id=run_id, key=self.key)
+                if self._exists()
+                else []
+            )
+
+            save_args = deepcopy(self._save_args)
+            step = save_args.pop("step", None)
+            if step is None:
+                if self.mode == "overwrite":
+                    step = max([metric.step for metric in metric_history], default=0)
+                elif self.mode == "append":
+                    # default the max() to -1 so that the first "step" equals 0
+                    step = (
+                        max([metric.step for metric in metric_history], default=-1) + 1
+                    )
+                else:
+                    raise ValueError(
+                        f"save_args['mode'] must be one of {self.SUPPORTED_SAVE_MODES}, got '{self.mode}' instead."
+ ) + + mlflow_client.log_metric( + run_id=run_id, + key=self.key, + value=data, + step=step, + **save_args, + ) diff --git a/tests/framework/hooks/test_all_hooks.py b/tests/framework/hooks/test_all_hooks.py index 95f30b8b..181ad1e7 100644 --- a/tests/framework/hooks/test_all_hooks.py +++ b/tests/framework/hooks/test_all_hooks.py @@ -28,12 +28,13 @@ def fake_fun(input): artifact = input - metric = { + metrics = { "metric1": {"value": 1.1, "step": 1}, "metric2": [{"value": 1.1, "step": 1}, {"value": 1.2, "step": 2}], } + metric = 1 model = 3 - return artifact, metric, model + return artifact, metrics, metric, metric, model @pytest.fixture @@ -87,6 +88,9 @@ def catalog_config(kedro_project_path): "metrics_data": { "type": "kedro_mlflow.io.metrics.MlflowMetricsDataSet", }, + "metric_data": { + "type": "kedro_mlflow.io.metrics.MlflowMetricDataSet", + }, "model": { "type": "kedro_mlflow.io.models.MlflowModelLoggerDataSet", "flavor": "mlflow.sklearn", @@ -212,7 +216,13 @@ def dummy_pipeline(): node( func=fake_fun, inputs=["params:a"], - outputs=["artifact_data", "metrics_data", "model"], + outputs=[ + "artifact_data", + "metrics_data", + "metric_data", + "metric_data_with_run_id", + "model", + ], ) ] ) @@ -273,9 +283,6 @@ def test_deactivated_tracking_but_not_for_given_pipeline( ] ) - # context = session.load_context() - # context.run(pipeline_name="pipeline_on") # this is a pipeline should be tracked - mock_session.run(pipeline_name="pipeline_on") all_runs_id_end = set( diff --git a/tests/framework/hooks/test_pipeline_hook.py b/tests/framework/hooks/test_pipeline_hook.py index 52281196..11cf9f3a 100644 --- a/tests/framework/hooks/test_pipeline_hook.py +++ b/tests/framework/hooks/test_pipeline_hook.py @@ -25,7 +25,7 @@ _format_conda_env, _generate_kedro_command, ) -from kedro_mlflow.io.metrics import MlflowMetricsDataSet +from kedro_mlflow.io.metrics import MlflowMetricDataSet, MlflowMetricsDataSet from kedro_mlflow.pipeline import pipeline_ml_factory from kedro_mlflow.pipeline.pipeline_ml import PipelineML @@ -215,9 +215,12 @@ def preprocess_fun(data): def train_fun(data, param): return 2 - def metric_fun(data, model): + def metrics_fun(data, model): return {"metric_key": {"value": 1.1, "step": 0}} + def metric_fun(data, model): + return 1.1 + def predict_fun(model, data): return data * model @@ -236,17 +239,29 @@ def predict_fun(model, data): tags=["training"], ), node( - func=metric_fun, + func=metrics_fun, inputs=["model", "data"], outputs="my_metrics", tags=["training"], ), node( - func=metric_fun, + func=metrics_fun, inputs=["model", "data"], outputs="another_metrics", tags=["training"], ), + node( + func=metric_fun, + inputs=["model", "data"], + outputs="my_metric", + tags=["training"], + ), + node( + func=metric_fun, + inputs=["model", "data"], + outputs="another_metric", + tags=["training"], + ), node( func=predict_fun, inputs=["model", "data"], @@ -281,6 +296,8 @@ def dummy_catalog(tmp_path): "model": PickleDataSet((tmp_path / "model.csv").as_posix()), "my_metrics": MlflowMetricsDataSet(), "another_metrics": MlflowMetricsDataSet(prefix="foo"), + "my_metric": MlflowMetricDataSet(), + "another_metric": MlflowMetricDataSet(key="foo"), } ) return dummy_catalog @@ -428,6 +445,8 @@ def test_mlflow_pipeline_hook_with_different_pipeline_types( # for metric assert dummy_catalog._data_sets["my_metrics"]._prefix == "my_metrics" assert dummy_catalog._data_sets["another_metrics"]._prefix == "foo" + assert dummy_catalog._data_sets["my_metric"].key == "my_metric" + assert 
dummy_catalog._data_sets["another_metric"].key == "foo" if isinstance(pipeline_to_run, PipelineML): trained_model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model") @@ -503,7 +522,7 @@ def test_mlflow_pipeline_hook_with_copy_mode( assert actual_copy_mode == expected -def test_mlflow_pipeline_hook_metrics_with_run_id( +def test_mlflow_pipeline_hook_metric_metrics_with_run_id( kedro_project_with_mlflow_conf, dummy_pipeline_ml, dummy_run_params ): @@ -528,6 +547,10 @@ def test_mlflow_pipeline_hook_metrics_with_run_id( "another_metrics": MlflowMetricsDataSet( run_id=existing_run_id, prefix="foo" ), + "my_metric": MlflowMetricDataSet(run_id=existing_run_id), + "another_metric": MlflowMetricDataSet( + run_id=existing_run_id, key="foo" + ), } ) @@ -578,8 +601,11 @@ def test_mlflow_pipeline_hook_metrics_with_run_id( # Check if metrics datasets have prefix with its names. # for metric assert all_runs_id == {current_run_id, existing_run_id} + print(run_data.metrics) assert run_data.metrics["my_metrics.metric_key"] == 1.1 assert run_data.metrics["foo.metric_key"] == 1.1 + assert run_data.metrics["my_metric"] == 1.1 + assert run_data.metrics["foo"] == 1.1 def test_mlflow_pipeline_hook_save_pipeline_ml_with_parameters( diff --git a/tests/io/metrics/test_mlflow_metric_dataset.py b/tests/io/metrics/test_mlflow_metric_dataset.py new file mode 100644 index 00000000..bc7b5f1f --- /dev/null +++ b/tests/io/metrics/test_mlflow_metric_dataset.py @@ -0,0 +1,202 @@ +import mlflow +import pytest +from kedro.io.core import DataSetError +from mlflow.tracking import MlflowClient + +from kedro_mlflow.io.metrics import MlflowMetricDataSet + + +@pytest.fixture +def mlflow_tracking_uri(tmp_path): + tracking_uri = (tmp_path / "mlruns").as_uri() + mlflow.set_tracking_uri(tracking_uri) + return tracking_uri + + +@pytest.fixture +def mlflow_client(mlflow_tracking_uri): + mlflow_client = MlflowClient(mlflow_tracking_uri) + return mlflow_client + + +def test_mlflow_wrong_save_mode(): + with pytest.raises(DataSetError, match=r"save_args\['mode'\] must be one of"): + metric_ds = MlflowMetricDataSet(key="my_metric", save_args={"mode": "bad_mode"}) + with mlflow.start_run(): + metric_ds.save(0.3) + + +def test_mlflow_metric_dataset_save_without_active_run_or_run_id(): + metric_ds = MlflowMetricDataSet(key="my_metric") + with pytest.raises( + DataSetError, + match="You must either specify a run_id or have a mlflow active run opened", + ): + metric_ds.save(0.3) + + +@pytest.mark.parametrize( + "save_args", + [ + (None), + ({}), + ({"mode": "append"}), + ({"mode": "overwrite"}), + ({"step": 2}), + ({"step": 2, "mode": "append"}), + ], +) +def test_mlflow_metric_dataset_save_with_active_run(mlflow_client, save_args): + metric_ds = MlflowMetricDataSet(key="my_metric", save_args=save_args) + with mlflow.start_run(): + metric_ds.save(0.3) + run_id = mlflow.active_run().info.run_id + metric_history = mlflow_client.get_metric_history(run_id, "my_metric") + + step = 0 if save_args is None else save_args.get("step", 0) + assert [ + (metric.key, metric.step, metric.value) for metric in metric_history + ] == [("my_metric", step, 0.3)] + + +@pytest.mark.parametrize( + "save_args", + [ + (None), + ({}), + ({"mode": "append"}), + ({"mode": "overwrite"}), + ({"step": 2}), + ({"step": 2, "mode": "append"}), + ], +) +def test_mlflow_metric_dataset_save_with_run_id(mlflow_client, save_args): + + # this time, the run is created first and closed + # the MlflowMetricDataSet should reopen it to interact + with mlflow.start_run(): + run_id = 
mlflow.active_run().info.run_id + + metric_ds = MlflowMetricDataSet(run_id=run_id, key="my_metric", save_args=save_args) + metric_ds.save(0.3) + metric_history = mlflow_client.get_metric_history(run_id, "my_metric") + step = 0 if save_args is None else save_args.get("step", 0) + assert [(metric.key, metric.step, metric.value) for metric in metric_history] == [ + ("my_metric", step, 0.3) + ] + assert mlflow.active_run() is None # no run should be opened + + +def test_mlflow_metric_dataset_save_append_mode(mlflow_client): + + with mlflow.start_run(): + run_id = mlflow.active_run().info.run_id + + metric_ds = MlflowMetricDataSet( + run_id=run_id, key="my_metric", save_args={"mode": "append"} + ) + + metric_ds.save(0.3) + metric_ds.save(1) + metric_history = mlflow_client.get_metric_history(run_id, "my_metric") + assert [(metric.key, metric.step, metric.value) for metric in metric_history] == [ + ("my_metric", 0, 0.3), + ("my_metric", 1, 1), + ] + + +def test_mlflow_metric_dataset_save_overwrite_mode(mlflow_client): + + with mlflow.start_run(): + run_id = mlflow.active_run().info.run_id + + # overwrite is the default mode + metric_ds = MlflowMetricDataSet(run_id=run_id, key="my_metric") + + metric_ds.save(0.3) + metric_ds.save(1) + metric_history = mlflow_client.get_metric_history(run_id, "my_metric") + assert [(metric.key, metric.step, metric.value) for metric in metric_history] == [ + ("my_metric", 0, 0.3), + ("my_metric", 0, 1), # same step + ] + + +def test_mlflow_metric_dataset_load(): + + with mlflow.start_run(): + run_id = mlflow.active_run().info.run_id + mlflow.log_metric(key="awesome_metric", value=0.1) + + # overwrite is the default mode + metric_ds = MlflowMetricDataSet(run_id=run_id, key="awesome_metric") + + assert metric_ds.load() == 0.1 + + +def test_mlflow_metric_dataset_load_last_logged_by_default_if_unordered(): + + with mlflow.start_run(): + run_id = mlflow.active_run().info.run_id + mlflow.log_metric(key="awesome_metric", value=0.4, step=3) + mlflow.log_metric(key="awesome_metric", value=0.3, step=2) + mlflow.log_metric(key="awesome_metric", value=0.2, step=1) + mlflow.log_metric(key="awesome_metric", value=0.1, step=0) + + # overwrite is the default mode + metric_ds = MlflowMetricDataSet(run_id=run_id, key="awesome_metric") + + assert ( + metric_ds.load() == 0.1 + ) # the last value is retrieved even if it has a smaller step + + +def test_mlflow_metric_dataset_load_given_step(): + + with mlflow.start_run(): + run_id = mlflow.active_run().info.run_id + mlflow.log_metric(key="awesome_metric", value=0.1, step=0) + mlflow.log_metric(key="awesome_metric", value=0.2, step=1) + mlflow.log_metric(key="awesome_metric", value=0.3, step=2) + mlflow.log_metric(key="awesome_metric", value=0.4, step=3) + + # overwrite is the default mode + metric_ds = MlflowMetricDataSet( + run_id=run_id, key="awesome_metric", load_args={"step": 2} + ) + + assert metric_ds.load() == 0.3 + + +def test_mlflow_metric_dataset_load_last_given_step_if_duplicated(): + + with mlflow.start_run(): + run_id = mlflow.active_run().info.run_id + mlflow.log_metric(key="awesome_metric", value=0.1, step=0) + mlflow.log_metric(key="awesome_metric", value=0.2, step=1) + mlflow.log_metric(key="awesome_metric", value=0.3, step=2) + mlflow.log_metric(key="awesome_metric", value=0.31, step=2) + mlflow.log_metric(key="awesome_metric", value=0.32, step=2) + mlflow.log_metric(key="awesome_metric", value=0.4, step=3) + + # overwrite is the default mode + metric_ds = MlflowMetricDataSet( + run_id=run_id, key="awesome_metric", 
load_args={"step": 2} + ) + + assert metric_ds.load() == 0.32 + + +def test_mlflow_metric_dataset_logging_deactivation(mlflow_tracking_uri): + metric_ds = MlflowMetricDataSet(key="inactive_metric") + metric_ds._logging_activated = False + with mlflow.start_run(): + metric_ds.save(1) + assert metric_ds._exists() is False + + +def test_mlflow_metric_logging_deactivation_is_bool(): + mlflow_metric_dataset = MlflowMetricDataSet(key="hello") + + with pytest.raises(ValueError, match="_logging_activated must be a boolean"): + mlflow_metric_dataset._logging_activated = "hello" diff --git a/tests/io/metrics/test_mlflow_metrics_dataset.py b/tests/io/metrics/test_mlflow_metrics_dataset.py index f88812af..82a44e02 100644 --- a/tests/io/metrics/test_mlflow_metrics_dataset.py +++ b/tests/io/metrics/test_mlflow_metrics_dataset.py @@ -171,7 +171,7 @@ def test_mlflow_metrics_dataset_fails_with_invalid_metric( ) # key: value is not valid, you must specify {key: {value, step}} -def test_mlflow_artifact_logging_deactivation(tracking_uri, metrics): +def test_mlflow_metrics_logging_deactivation(tracking_uri, metrics): mlflow_metrics_dataset = MlflowMetricsDataSet(prefix="hello") mlflow.set_tracking_uri(tracking_uri.as_uri())