From f45f36820a29a9afd7719dc175543d9bcef103e5 Mon Sep 17 00:00:00 2001 From: Diedre Carmo Date: Tue, 10 Nov 2020 08:50:25 -0300 Subject: [PATCH] fix logged keys in mlflow logger (#4412) * [#4411] fix gpu_log_memory with mlflow logger * sanitize parenthesis instead of removing for all loggers * apply regex for mlflow key sanitization * replace ',' with '.' typo * add single warning and test Co-authored-by: Rohit Gupta Co-authored-by: chaton (cherry picked from commit 470e2945fc0fefa1b6d42f1d64fef9f22f5c75b3) --- pytorch_lightning/loggers/mlflow.py | 9 +++++++++ tests/loggers/test_mlflow.py | 3 ++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index de915785dcb45..ee9f8f86cf247 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -16,6 +16,8 @@ MLflow ------ """ +import re +import warnings from argparse import Namespace from time import time from typing import Any, Dict, Optional, Union @@ -151,6 +153,13 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> if isinstance(v, str): log.warning(f'Discarding metric with string value {k}={v}.') continue + + new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k) + if k != new_k: + warnings.warn(("MLFlow only allows '_', '/', '.' and ' ' special characters in metric name.\n", + f"Replacing {k} with {new_k}.")) + k = new_k + self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step) @rank_zero_only diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index db2c353dc4e2c..618f0ffe60903 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -137,7 +137,8 @@ def test_mlflow_logger_dirs_creation(tmpdir): assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'} model = EvalModelTemplate() - trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3) + trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3, + log_gpu_memory=True) trainer.fit(model) assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'} assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics')