Skip to content

Commit

Permalink
fix logged keys in mlflow logger (#4412)
Browse files Browse the repository at this point in the history
* [#4411] fix gpu_log_memory with mlflow logger

* sanitize parenthesis instead of removing for all loggers

* apply regex for mlflow key sanitization

* replace ',' with '.' typo

* add single warning and test

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: chaton <thomas@grid.ai>
  • Loading branch information
3 people committed Nov 10, 2020
1 parent 11415fa commit 470e294
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
9 changes: 9 additions & 0 deletions pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
MLflow
------
"""
import re
import warnings
from argparse import Namespace
from time import time
from typing import Any, Dict, Optional, Union
Expand Down Expand Up @@ -151,6 +153,13 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
if isinstance(v, str):
log.warning(f'Discarding metric with string value {k}={v}.')
continue

new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k)
if k != new_k:
warnings.warn(("MLFlow only allows '_', '/', '.' and ' ' special characters in metric name.\n",
f"Replacing {k} with {new_k}."))
k = new_k

self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)

@rank_zero_only
Expand Down
3 changes: 2 additions & 1 deletion tests/loggers/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ def test_mlflow_logger_dirs_creation(tmpdir):
assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'}

model = EvalModelTemplate()
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3,
log_gpu_memory=True)
trainer.fit(model)
assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'}
assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics')
Expand Down

0 comments on commit 470e294

Please sign in to comment.