diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index de915785dcb45..ee9f8f86cf247 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -16,6 +16,8 @@ MLflow ------ """ +import re +import warnings from argparse import Namespace from time import time from typing import Any, Dict, Optional, Union @@ -151,6 +153,13 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> if isinstance(v, str): log.warning(f'Discarding metric with string value {k}={v}.') continue + + new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k) + if k != new_k: + warnings.warn(("MLFlow only allows '_', '/', '.' and ' ' special characters in metric name.\n", + f"Replacing {k} with {new_k}.")) + k = new_k + self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step) @rank_zero_only diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index db2c353dc4e2c..618f0ffe60903 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -137,7 +137,8 @@ def test_mlflow_logger_dirs_creation(tmpdir): assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'} model = EvalModelTemplate() - trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3) + trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3, + log_gpu_memory=True) trainer.fit(model) assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'} assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics')