Skip to content

Commit

Permalink
Mocking Loggers (part 3b, comet) (#3853)
Browse files Browse the repository at this point in the history
* ref

* Mocking Loggers (part 3c, comet) (#3859)

* mock comet

* new line
  • Loading branch information
awaelchli committed Oct 6, 2020
1 parent 2119184 commit 893bed7
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 56 deletions.
50 changes: 23 additions & 27 deletions pytorch_lightning/loggers/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,25 @@
from typing import Any, Dict, Optional, Union

try:
from comet_ml import BaseExperiment as CometBaseExperiment
import comet_ml

except ModuleNotFoundError: # pragma: no-cover
comet_ml = None
CometExperiment = None
CometExistingExperiment = None
CometOfflineExperiment = None
API = None
generate_guid = None
else:
from comet_ml import ExistingExperiment as CometExistingExperiment
from comet_ml import Experiment as CometExperiment
from comet_ml import OfflineExperiment as CometOfflineExperiment
from comet_ml import generate_guid

try:
from comet_ml.api import API
except ImportError: # pragma: no-cover
# For more information, see: https://www.comet.ml/docs/python-sdk/releases/#release-300
from comet_ml.papi import API # pragma: no-cover
from comet_ml.config import get_api_key, get_config
except ImportError: # pragma: no-cover
CometExperiment = None
CometExistingExperiment = None
CometOfflineExperiment = None
CometBaseExperiment = None
API = None
generate_guid = None
_COMET_AVAILABLE = False
else:
_COMET_AVAILABLE = True

import torch
from torch import is_tensor
Expand Down Expand Up @@ -117,17 +114,17 @@ class CometLogger(LightningLoggerBase):
"""

def __init__(
self,
api_key: Optional[str] = None,
save_dir: Optional[str] = None,
project_name: Optional[str] = None,
rest_api_key: Optional[str] = None,
experiment_name: Optional[str] = None,
experiment_key: Optional[str] = None,
offline: bool = False,
**kwargs
self,
api_key: Optional[str] = None,
save_dir: Optional[str] = None,
project_name: Optional[str] = None,
rest_api_key: Optional[str] = None,
experiment_name: Optional[str] = None,
experiment_key: Optional[str] = None,
offline: bool = False,
**kwargs
):
if not _COMET_AVAILABLE:
if comet_ml is None:
raise ImportError(
"You want to use `comet_ml` logger which is not installed yet,"
" install it with `pip install comet-ml`."
Expand All @@ -136,7 +133,7 @@ def __init__(
self._experiment = None

# Determine online or offline mode based on which arguments were passed to CometLogger
api_key = api_key or get_api_key(None, get_config())
api_key = api_key or comet_ml.config.get_api_key(None, comet_ml.config.get_config())

if api_key is not None and save_dir is not None:
self.mode = "offline" if offline else "online"
Expand Down Expand Up @@ -173,7 +170,7 @@ def __init__(

@property
@rank_zero_experiment
def experiment(self) -> CometBaseExperiment:
def experiment(self):
r"""
Actual Comet object. To use Comet features in your
:class:`~pytorch_lightning.core.lightning.LightningModule` do the following.
Expand Down Expand Up @@ -236,7 +233,6 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti

metrics_without_epoch = metrics.copy()
epoch = metrics_without_epoch.pop('epoch', None)

self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch)

def reset_experiment(self):
Expand Down Expand Up @@ -284,7 +280,7 @@ def version(self) -> str:
return self._future_experiment_key

# Pre-generate an experiment key
self._future_experiment_key = generate_guid()
self._future_experiment_key = comet_ml.generate_guid()

return self._future_experiment_key

Expand Down
80 changes: 51 additions & 29 deletions tests/loggers/test_comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,24 @@ def _patch_comet_atexit(monkeypatch):
monkeypatch.setattr(atexit, "register", lambda _: None)


def test_comet_logger_online():
@patch('pytorch_lightning.loggers.comet.comet_ml')
def test_comet_logger_online(comet):
"""Test comet online with mocks."""
# Test api_key given
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment:
logger = CometLogger(api_key='key', workspace='dummy-test', project_name='general')

_ = logger.experiment

comet.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general')
comet_experiment.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general')

# Test both given
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment:
logger = CometLogger(save_dir='test', api_key='key', workspace='dummy-test', project_name='general')

_ = logger.experiment

comet.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general')

# Test neither given
with pytest.raises(MisconfigurationException):
CometLogger(workspace='dummy-test', project_name='general')
comet_experiment.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general')

# Test already exists
with patch('pytorch_lightning.loggers.comet.CometExistingExperiment') as comet_existing:
Expand All @@ -61,52 +58,73 @@ def test_comet_logger_online():
api.assert_called_once_with('rest')


def test_comet_logger_experiment_name():
@patch('pytorch_lightning.loggers.comet.comet_ml')
def test_comet_logger_no_api_key_given(comet):
""" Test that CometLogger fails to initialize if both api key and save_dir are missing. """
with pytest.raises(MisconfigurationException):
comet.config.get_api_key.return_value = None
CometLogger(workspace='dummy-test', project_name='general')


@patch('pytorch_lightning.loggers.comet.comet_ml')
def test_comet_logger_experiment_name(comet):
"""Test that Comet Logger experiment name works correctly."""

api_key = "key"
experiment_name = "My Name"

# Test api_key given
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment:
logger = CometLogger(api_key=api_key, experiment_name=experiment_name,)

assert logger._experiment is None

_ = logger.experiment

comet.assert_called_once_with(api_key=api_key, project_name=None)
comet_experiment.assert_called_once_with(api_key=api_key, project_name=None)

comet().set_name.assert_called_once_with(experiment_name)
comet_experiment().set_name.assert_called_once_with(experiment_name)


def test_comet_logger_dirs_creation(tmpdir, monkeypatch):
@patch('pytorch_lightning.loggers.comet.CometOfflineExperiment')
@patch('pytorch_lightning.loggers.comet.comet_ml')
def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch):
""" Test that the logger creates the folders and files in the right place. """
_patch_comet_atexit(monkeypatch)

comet.config.get_api_key.return_value = None
comet.generate_guid.return_value = "4321"

logger = CometLogger(project_name='test', save_dir=tmpdir)
assert not os.listdir(tmpdir)
assert logger.mode == 'offline'
assert logger.save_dir == tmpdir
assert logger.name == 'test'
assert logger.version == "4321"

_ = logger.experiment
version = logger.version
assert set(os.listdir(tmpdir)) == {f'{logger.experiment.id}.zip'}

comet_experiment.assert_called_once_with(offline_directory=tmpdir, project_name='test')

# mock return values of experiment
logger.experiment.id = '1'
logger.experiment.project_name = 'test'

model = EvalModelTemplate()
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
trainer.fit(model)

assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / version / 'checkpoints')
assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / "1" / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}


def test_comet_name_default():
@patch('pytorch_lightning.loggers.comet.comet_ml')
def test_comet_name_default(comet):
""" Test that CometLogger.name don't create an Experiment and returns a default value. """

api_key = "key"

with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
with patch('pytorch_lightning.loggers.comet.CometExperiment'):
logger = CometLogger(api_key=api_key)

assert logger._experiment is None
Expand All @@ -116,13 +134,14 @@ def test_comet_name_default():
assert logger._experiment is None


def test_comet_name_project_name():
@patch('pytorch_lightning.loggers.comet.comet_ml')
def test_comet_name_project_name(comet):
""" Test that CometLogger.name does not create an Experiment and returns project name if passed. """

api_key = "key"
project_name = "My Project Name"

with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
with patch('pytorch_lightning.loggers.comet.CometExperiment'):
logger = CometLogger(api_key=api_key, project_name=project_name)

assert logger._experiment is None
Expand All @@ -132,13 +151,15 @@ def test_comet_name_project_name():
assert logger._experiment is None


def test_comet_version_without_experiment():
@patch('pytorch_lightning.loggers.comet.comet_ml')
def test_comet_version_without_experiment(comet):
""" Test that CometLogger.version does not create an Experiment. """

api_key = "key"
experiment_name = "My Name"
comet.generate_guid.return_value = "1234"

with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
with patch('pytorch_lightning.loggers.comet.CometExperiment'):
logger = CometLogger(api_key=api_key, experiment_name=experiment_name)

assert logger._experiment is None
Expand All @@ -154,15 +175,16 @@ def test_comet_version_without_experiment():

logger.reset_experiment()

second_version = logger.version
second_version = logger.version == "1234"
assert second_version is not None
assert second_version != first_version


def test_comet_epoch_logging(tmpdir, monkeypatch):
@patch("pytorch_lightning.loggers.comet.CometExperiment")
@patch('pytorch_lightning.loggers.comet.comet_ml')
def test_comet_epoch_logging(comet, comet_experiment, tmpdir, monkeypatch):
""" Test that CometLogger removes the epoch key from the metrics dict and passes it as argument. """
_patch_comet_atexit(monkeypatch)
with patch("pytorch_lightning.loggers.comet.CometOfflineExperiment.log_metrics") as log_metrics:
logger = CometLogger(project_name="test", save_dir=tmpdir)
logger.log_metrics({"test": 1, "epoch": 1}, step=123)
log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123)
logger = CometLogger(project_name="test", save_dir=tmpdir)
logger.log_metrics({"test": 1, "epoch": 1}, step=123)
logger.experiment.log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123)

0 comments on commit 893bed7

Please sign in to comment.