From 7eafd8eac63518a42075279a5e6855085c8bbe3d Mon Sep 17 00:00:00 2001 From: i-aki-y Date: Fri, 21 May 2021 17:17:32 +0900 Subject: [PATCH] Add run_name argument to the MLFlowLogger constructor (#7622) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add run_name argument to the MLFlowLogger * Update CHANGELOG * Fix unnecessary line * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix style by using yapf * Fix import error when mlflow is not installed * Update CHANGELOG.md * Update tests/loggers/test_mlflow.py Co-authored-by: akiyuki ishikawa Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Adrian Wälchli --- CHANGELOG.md | 3 +++ pytorch_lightning/loggers/mlflow.py | 14 ++++++++++++++ tests/loggers/test_mlflow.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b7178a4b01ea0..8696f739d4071 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -74,6 +74,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- MLflowLogger now uses the env variable `MLFLOW_TRACKING_URI` as default tracking uri ([#7457](https://github.com/PyTorchLightning/pytorch-lightning/pull/7457)) +- MLFlowLogger now accepts `run_name` as a constructor argument ([#7622](https://github.com/PyTorchLightning/pytorch-lightning/pull/7622)) + + ### Deprecated diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index fbcd4bbcc5183..1426adbe1104a 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -31,10 +31,12 @@ try: import mlflow from mlflow.tracking import context, MlflowClient + from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME # todo: there seems to be still some remaining import error with Conda env except ImportError: _MLFLOW_AVAILABLE = False mlflow, MlflowClient, context = None, None, None + MLFLOW_RUN_NAME = "mlflow.runName" # before v1.1.0 if hasattr(context, 'resolve_tags'): @@ -85,6 +87,8 @@ def any_lightning_module_function_or_hook(self): Args: experiment_name: The name of the experiment + run_name: Name of the new run. The `run_name` is internally stored as a ``mlflow.runName`` tag. + If the ``mlflow.runName`` tag has already been set in `tags`, the value is overridden by the `run_name`. tracking_uri: Address of local or remote tracking server. If not provided, defaults to `MLFLOW_TRACKING_URI` environment variable if set, otherwise it falls back to `file:`. 
@@ -106,6 +110,7 @@ def any_lightning_module_function_or_hook(self): def __init__( self, experiment_name: str = 'default', + run_name: Optional[str] = None, tracking_uri: Optional[str] = os.getenv('MLFLOW_TRACKING_URI'), tags: Optional[Dict[str, Any]] = None, save_dir: Optional[str] = './mlruns', @@ -124,6 +129,7 @@ def __init__( self._experiment_name = experiment_name self._experiment_id = None self._tracking_uri = tracking_uri + self._run_name = run_name self._run_id = None self.tags = tags self._prefix = prefix @@ -155,6 +161,14 @@ def experiment(self) -> MlflowClient: ) if self._run_id is None: + if self._run_name is not None: + self.tags = self.tags or {} + if MLFLOW_RUN_NAME in self.tags: + log.warning( + f'The tag {MLFLOW_RUN_NAME} is found in tags. ' + f'The value will be overridden by {self._run_name}.' + ) + self.tags[MLFLOW_RUN_NAME] = self._run_name run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags)) self._run_id = run.info.run_id return self._mlflow_client diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index 35bad766798b1..d798cb9f16f7e 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -19,6 +19,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import _MLFLOW_AVAILABLE, MLFlowLogger +from pytorch_lightning.loggers.mlflow import MLFLOW_RUN_NAME, resolve_tags from tests.helpers import BoringModel @@ -85,6 +86,33 @@ def test_mlflow_logger_exists(client, mlflow, tmpdir): assert logger3.run_id == "run-id-3" +@mock.patch('pytorch_lightning.loggers.mlflow.mlflow') +@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient') +def test_mlflow_run_name_setting(client, mlflow, tmpdir): + """ Test that the run_name argument makes the MLFLOW_RUN_NAME tag. 
""" + + tags = resolve_tags({MLFLOW_RUN_NAME: 'run-name-1'}) + + # run_name is appended to tags + logger = MLFlowLogger('test', run_name='run-name-1', save_dir=tmpdir) + logger = mock_mlflow_run_creation(logger, experiment_id='exp-id') + _ = logger.experiment + client.return_value.create_run.assert_called_with(experiment_id='exp-id', tags=tags) + + # run_name overrides tags[MLFLOW_RUN_NAME] + logger = MLFlowLogger('test', run_name='run-name-1', tags={MLFLOW_RUN_NAME: "run-name-2"}, save_dir=tmpdir) + logger = mock_mlflow_run_creation(logger, experiment_id='exp-id') + _ = logger.experiment + client.return_value.create_run.assert_called_with(experiment_id='exp-id', tags=tags) + + # default run_name (= None) does not append new tag + logger = MLFlowLogger('test', save_dir=tmpdir) + logger = mock_mlflow_run_creation(logger, experiment_id='exp-id') + _ = logger.experiment + default_tags = resolve_tags(None) + client.return_value.create_run.assert_called_with(experiment_id='exp-id', tags=default_tags) + + @mock.patch("pytorch_lightning.loggers.mlflow.mlflow") @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") def test_mlflow_log_dir(client, mlflow, tmpdir):