Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed a crash bug in MLFlow logger #4716

Merged
merged 9 commits into from
Nov 24, 2020
11 changes: 6 additions & 5 deletions pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
-------------
"""
import re
import warnings
from argparse import Namespace
from time import time
from typing import Any, Dict, Optional, Union
Expand All @@ -32,7 +31,7 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn

LOCAL_FILE_URI_PREFIX = "file:"

Expand Down Expand Up @@ -165,9 +164,11 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->

new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k)
if k != new_k:
warnings.warn(("MLFlow only allows '_', '/', '.' and ' ' special characters in metric name.\n",
f"Replacing {k} with {new_k}."))
k = new_k
rank_zero_warn(
"MLFlow only allows '_', '/', '.' and ' ' special characters in metric name."
f" Replacing {k} with {new_k}.", RuntimeWarning
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
)
k = new_k

self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)

Expand Down
16 changes: 16 additions & 0 deletions tests/loggers/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,24 @@ def test_mlflow_logger_dirs_creation(tmpdir):
@mock.patch('pytorch_lightning.loggers.mlflow.mlflow')
@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient')
def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir):
"""
Test that the logger experiment_id retrieved only once.
"""
logger = MLFlowLogger('test', save_dir=tmpdir)
_ = logger.experiment
_ = logger.experiment
_ = logger.experiment
assert logger.experiment.get_experiment_by_name.call_count == 1


@mock.patch('pytorch_lightning.loggers.mlflow.mlflow')
@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient')
def test_mlflow_logger_with_unexpected_characters(client, mlflow, tmpdir):
"""
Test that the logger raises warning with special characters not accepted by MLFlow.
"""
logger = MLFlowLogger('test', save_dir=tmpdir)
metrics = {'[some_metric]': 10}

with pytest.warns(RuntimeWarning, match='special characters in metric name'):
logger.log_metrics(metrics)