Skip to content

Commit

Permalink
Add run_name argument to the MLFlowLogger constructor (#7622)
Browse files Browse the repository at this point in the history
* Add run_name argument to the MLFlowLogger

* Update CHANGELOG

* Fix unnecessary line

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix style by using yapf

* Fix import error when mlflow is not installed

* Update CHANGELOG.md

* Update tests/loggers/test_mlflow.py

Co-authored-by: akiyuki ishikawa <aki.y.ishikwa@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
  • Loading branch information
4 people authored May 21, 2021
1 parent 94ef17c commit 7eafd8e
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- MLflowLogger now uses the env variable `MLFLOW_TRACKING_URI` as default tracking uri ([#7457](https://github.com/PyTorchLightning/pytorch-lightning/pull/7457))


- MLFlowLogger now accepts `run_name` as an constructor argument ([#7622](https://github.com/PyTorchLightning/pytorch-lightning/issues/7622))


### Deprecated


Expand Down
14 changes: 14 additions & 0 deletions pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@
try:
import mlflow
from mlflow.tracking import context, MlflowClient
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
# todo: there seems to be still some remaining import error with Conda env
except ImportError:
_MLFLOW_AVAILABLE = False
mlflow, MlflowClient, context = None, None, None
MLFLOW_RUN_NAME = "mlflow.runName"

# before v1.1.0
if hasattr(context, 'resolve_tags'):
Expand Down Expand Up @@ -85,6 +87,8 @@ def any_lightning_module_function_or_hook(self):
Args:
experiment_name: The name of the experiment
run_name: Name of the new run. The `run_name` is internally stored as a ``mlflow.runName`` tag.
If the ``mlflow.runName`` tag has already been set in `tags`, the value is overridden by the `run_name`.
tracking_uri: Address of local or remote tracking server.
If not provided, defaults to `MLFLOW_TRACKING_URI` environment variable if set, otherwise it falls
back to `file:<save_dir>`.
Expand All @@ -106,6 +110,7 @@ def any_lightning_module_function_or_hook(self):
def __init__(
self,
experiment_name: str = 'default',
run_name: Optional[str] = None,
tracking_uri: Optional[str] = os.getenv('MLFLOW_TRACKING_URI'),
tags: Optional[Dict[str, Any]] = None,
save_dir: Optional[str] = './mlruns',
Expand All @@ -124,6 +129,7 @@ def __init__(
self._experiment_name = experiment_name
self._experiment_id = None
self._tracking_uri = tracking_uri
self._run_name = run_name
self._run_id = None
self.tags = tags
self._prefix = prefix
Expand Down Expand Up @@ -155,6 +161,14 @@ def experiment(self) -> MlflowClient:
)

if self._run_id is None:
if self._run_name is not None:
self.tags = self.tags or {}
if MLFLOW_RUN_NAME in self.tags:
log.warning(
f'The tag {MLFLOW_RUN_NAME} is found in tags. '
f'The value will be overridden by {self._run_name}.'
)
self.tags[MLFLOW_RUN_NAME] = self._run_name
run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags))
self._run_id = run.info.run_id
return self._mlflow_client
Expand Down
28 changes: 28 additions & 0 deletions tests/loggers/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import _MLFLOW_AVAILABLE, MLFlowLogger
from pytorch_lightning.loggers.mlflow import MLFLOW_RUN_NAME, resolve_tags
from tests.helpers import BoringModel


Expand Down Expand Up @@ -85,6 +86,33 @@ def test_mlflow_logger_exists(client, mlflow, tmpdir):
assert logger3.run_id == "run-id-3"


@mock.patch('pytorch_lightning.loggers.mlflow.mlflow')
@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient')
def test_mlflow_run_name_setting(client, mlflow, tmpdir):
""" Test that the run_name argument makes the MLFLOW_RUN_NAME tag. """

tags = resolve_tags({MLFLOW_RUN_NAME: 'run-name-1'})

# run_name is appended to tags
logger = MLFlowLogger('test', run_name='run-name-1', save_dir=tmpdir)
logger = mock_mlflow_run_creation(logger, experiment_id='exp-id')
_ = logger.experiment
client.return_value.create_run.assert_called_with(experiment_id='exp-id', tags=tags)

# run_name overrides tags[MLFLOW_RUN_NAME]
logger = MLFlowLogger('test', run_name='run-name-1', tags={MLFLOW_RUN_NAME: "run-name-2"}, save_dir=tmpdir)
logger = mock_mlflow_run_creation(logger, experiment_id='exp-id')
_ = logger.experiment
client.return_value.create_run.assert_called_with(experiment_id='exp-id', tags=tags)

# default run_name (= None) does not append new tag
logger = MLFlowLogger('test', save_dir=tmpdir)
logger = mock_mlflow_run_creation(logger, experiment_id='exp-id')
_ = logger.experiment
default_tags = resolve_tags(None)
client.return_value.create_run.assert_called_with(experiment_id='exp-id', tags=default_tags)


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_log_dir(client, mlflow, tmpdir):
Expand Down

0 comments on commit 7eafd8e

Please sign in to comment.