Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add run_name argument to the MLFlowLogger constructor #7622

Merged
merged 9 commits into from
May 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- MLflowLogger now uses the env variable `MLFLOW_TRACKING_URI` as default tracking uri ([#7457](https://github.com/PyTorchLightning/pytorch-lightning/pull/7457))


- MLFlowLogger now accepts `run_name` as a constructor argument ([#7622](https://github.com/PyTorchLightning/pytorch-lightning/issues/7622))


### Deprecated


Expand Down
14 changes: 14 additions & 0 deletions pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@
try:
import mlflow
from mlflow.tracking import context, MlflowClient
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
# todo: there seems to be still some remaining import error with Conda env
except ImportError:
_MLFLOW_AVAILABLE = False
mlflow, MlflowClient, context = None, None, None
MLFLOW_RUN_NAME = "mlflow.runName"

# before v1.1.0
if hasattr(context, 'resolve_tags'):
Expand Down Expand Up @@ -85,6 +87,8 @@ def any_lightning_module_function_or_hook(self):

Args:
experiment_name: The name of the experiment
run_name: Name of the new run. The `run_name` is internally stored as a ``mlflow.runName`` tag.
If the ``mlflow.runName`` tag has already been set in `tags`, the value is overridden by the `run_name`.
tracking_uri: Address of local or remote tracking server.
If not provided, defaults to `MLFLOW_TRACKING_URI` environment variable if set, otherwise it falls
back to `file:<save_dir>`.
Expand All @@ -106,6 +110,7 @@ def any_lightning_module_function_or_hook(self):
def __init__(
self,
experiment_name: str = 'default',
run_name: Optional[str] = None,
tracking_uri: Optional[str] = os.getenv('MLFLOW_TRACKING_URI'),
tags: Optional[Dict[str, Any]] = None,
save_dir: Optional[str] = './mlruns',
Expand All @@ -124,6 +129,7 @@ def __init__(
self._experiment_name = experiment_name
self._experiment_id = None
self._tracking_uri = tracking_uri
self._run_name = run_name
self._run_id = None
self.tags = tags
self._prefix = prefix
Expand Down Expand Up @@ -155,6 +161,14 @@ def experiment(self) -> MlflowClient:
)

if self._run_id is None:
if self._run_name is not None:
self.tags = self.tags or {}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we maybe lift this to the init already? So that self.tags is always a dict?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I have a concern about this. The current self.tags appears to be a public attribute (it has no '_' prefix), so I am worried that a user could set it to None after initialization.

Note that the documentation does not say that tags is public, so I'm not sure whether this is intentional:
https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.loggers.mlflow.html#mlflow-logger

Another minor reason why I put this here is to minimize side effects. If run_name is not set explicitly, self.tags will still be 'None' after this PR. This is the same behavior as the current version.

I could not judge which is better, so any suggestions and comments are welcome!

if MLFLOW_RUN_NAME in self.tags:
log.warning(
f'The tag {MLFLOW_RUN_NAME} is found in tags. '
f'The value will be overridden by {self._run_name}.'
)
self.tags[MLFLOW_RUN_NAME] = self._run_name
run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags))
self._run_id = run.info.run_id
return self._mlflow_client
Expand Down
28 changes: 28 additions & 0 deletions tests/loggers/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import _MLFLOW_AVAILABLE, MLFlowLogger
from pytorch_lightning.loggers.mlflow import MLFLOW_RUN_NAME, resolve_tags
from tests.helpers import BoringModel


Expand Down Expand Up @@ -85,6 +86,33 @@ def test_mlflow_logger_exists(client, mlflow, tmpdir):
assert logger3.run_id == "run-id-3"


@mock.patch('pytorch_lightning.loggers.mlflow.mlflow')
@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient')
def test_mlflow_run_name_setting(client, mlflow, tmpdir):
    """Verify how ``run_name`` is translated into the ``MLFLOW_RUN_NAME`` tag.

    Three scenarios are checked against the ``tags`` passed to ``create_run``:

    1. ``run_name`` given without user tags -> the tag is added.
    2. ``run_name`` given together with a conflicting tag -> ``run_name`` wins.
    3. ``run_name`` left at its default -> only the default tags are sent.
    """
    expected_tags = resolve_tags({MLFLOW_RUN_NAME: 'run-name-1'})

    # Scenario 1: the run name is the only source of tags.
    named_logger = mock_mlflow_run_creation(
        MLFlowLogger('test', run_name='run-name-1', save_dir=tmpdir),
        experiment_id='exp-id',
    )
    _ = named_logger.experiment  # reading the property is what triggers run creation
    # NOTE(review): ``create_run`` is looked up afresh after every helper call;
    # the helper is defined outside this view and may install a new mock.
    client.return_value.create_run.assert_called_with(experiment_id='exp-id', tags=expected_tags)

    # Scenario 2: a user-supplied MLFLOW_RUN_NAME tag is replaced by ``run_name``.
    overriding_logger = mock_mlflow_run_creation(
        MLFlowLogger(
            'test',
            run_name='run-name-1',
            tags={MLFLOW_RUN_NAME: "run-name-2"},
            save_dir=tmpdir,
        ),
        experiment_id='exp-id',
    )
    _ = overriding_logger.experiment
    client.return_value.create_run.assert_called_with(experiment_id='exp-id', tags=expected_tags)

    # Scenario 3: without ``run_name`` no extra tag is injected.
    unnamed_logger = mock_mlflow_run_creation(
        MLFlowLogger('test', save_dir=tmpdir),
        experiment_id='exp-id',
    )
    _ = unnamed_logger.experiment
    baseline_tags = resolve_tags(None)
    client.return_value.create_run.assert_called_with(experiment_id='exp-id', tags=baseline_tags)


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_log_dir(client, mlflow, tmpdir):
Expand Down