From 60f4711bc419bbfcbff9952e71c46d12bba21b5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 5 Oct 2020 03:10:59 +0200 Subject: [PATCH 1/2] mock comet --- tests/loggers/test_comet.py | 95 ++++++++++++++++++++++--------------- 1 file changed, 58 insertions(+), 37 deletions(-) diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index 16e8d8551b6e5..89b13c920d4d9 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -9,27 +9,30 @@ from tests.base import EvalModelTemplate -def test_comet_logger_online(): +def _patch_comet_atexit(monkeypatch): + """ Prevent comet logger from trying to print at exit, since pytest's stdout/stderr redirection breaks it. """ + import atexit + monkeypatch.setattr(atexit, "register", lambda _: None) + + +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_logger_online(comet): """Test comet online with mocks.""" # Test api_key given - with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet: + with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment: logger = CometLogger(api_key='key', workspace='dummy-test', project_name='general') _ = logger.experiment - comet.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general') + comet_experiment.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general') # Test both given - with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet: + with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment: logger = CometLogger(save_dir='test', api_key='key', workspace='dummy-test', project_name='general') _ = logger.experiment - comet.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general') - - # Test neither given - with pytest.raises(MisconfigurationException): - CometLogger(workspace='dummy-test', project_name='general') + comet_experiment.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general') # Test already exists with patch('pytorch_lightning.loggers.comet.CometExistingExperiment') as comet_existing: @@ -55,56 +58,72 @@ def test_comet_logger_online(): api.assert_called_once_with('rest') -def test_comet_logger_experiment_name(): +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_logger_no_api_key_given(comet): + """ Test that CometLogger fails to initialize if both api key and save_dir are missing. """ + with pytest.raises(MisconfigurationException): + comet.config.get_api_key.return_value = None + CometLogger(workspace='dummy-test', project_name='general') + + +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_logger_experiment_name(comet): """Test that Comet Logger experiment name works correctly.""" api_key = "key" experiment_name = "My Name" # Test api_key given - with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet: + with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment: logger = CometLogger(api_key=api_key, experiment_name=experiment_name,) assert logger._experiment is None _ = logger.experiment - comet.assert_called_once_with(api_key=api_key, project_name=None) + comet_experiment.assert_called_once_with(api_key=api_key, project_name=None) - comet().set_name.assert_called_once_with(experiment_name) + comet_experiment().set_name.assert_called_once_with(experiment_name) -def test_comet_logger_dirs_creation(tmpdir, monkeypatch): +@patch('pytorch_lightning.loggers.comet.CometOfflineExperiment') +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch): """ Test that the logger creates the folders and files in the right place. """ - # prevent comet logger from trying to print at exit, since - # pytest's stdout/stderr redirection breaks it - import atexit - - monkeypatch.setattr(atexit, 'register', lambda _: None) + _patch_comet_atexit(monkeypatch) + comet.config.get_api_key.return_value = None + comet.generate_guid.return_value = "4321" logger = CometLogger(project_name='test', save_dir=tmpdir) assert not os.listdir(tmpdir) assert logger.mode == 'offline' assert logger.save_dir == tmpdir + assert logger.name == 'test' + assert logger.version == "4321" _ = logger.experiment - version = logger.version - assert set(os.listdir(tmpdir)) == {f'{logger.experiment.id}.zip'} + + comet_experiment.assert_called_once_with(offline_directory=tmpdir, project_name='test') + + # mock return values of experiment + logger.experiment.id = '1' + logger.experiment.project_name = 'test' model = EvalModelTemplate() trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3) trainer.fit(model) - assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / version / 'checkpoints') + assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / "1" / 'checkpoints') assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'} -def test_comet_name_default(): +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_name_default(comet): """ Test that CometLogger.name don't create an Experiment and returns a default value. """ api_key = "key" - with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet: + with patch('pytorch_lightning.loggers.comet.CometExperiment'): logger = CometLogger(api_key=api_key) assert logger._experiment is None @@ -114,13 +133,14 @@ def test_comet_name_default(): assert logger._experiment is None -def test_comet_name_project_name(): +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_name_project_name(comet): """ Test that CometLogger.name does not create an Experiment and returns project name if passed. """ api_key = "key" project_name = "My Project Name" - with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet: + with patch('pytorch_lightning.loggers.comet.CometExperiment'): logger = CometLogger(api_key=api_key, project_name=project_name) assert logger._experiment is None @@ -130,13 +150,15 @@ def test_comet_name_project_name(): assert logger._experiment is None -def test_comet_version_without_experiment(): +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_version_without_experiment(comet): """ Test that CometLogger.version does not create an Experiment. """ api_key = "key" experiment_name = "My Name" + comet.generate_guid.return_value = "1234" - with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet: + with patch('pytorch_lightning.loggers.comet.CometExperiment'): logger = CometLogger(api_key=api_key, experiment_name=experiment_name) assert logger._experiment is None @@ -152,17 +174,16 @@ def test_comet_version_without_experiment(): logger.reset_experiment() - second_version = logger.version + second_version = logger.version == "1234" assert second_version is not None assert second_version != first_version -def test_comet_epoch_logging(tmpdir, monkeypatch): +@patch("pytorch_lightning.loggers.comet.CometExperiment") +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_epoch_logging(comet, comet_experiment, tmpdir, monkeypatch): """ Test that CometLogger removes the epoch key from the metrics dict and passes it as argument. """ - import atexit - - monkeypatch.setattr(atexit, "register", lambda _: None) - with patch("pytorch_lightning.loggers.comet.CometOfflineExperiment.log_metrics") as log_metrics: - logger = CometLogger(project_name="test", save_dir=tmpdir) - logger.log_metrics({"test": 1, "epoch": 1}, step=123) - log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123) + _patch_comet_atexit(monkeypatch) + logger = CometLogger(project_name="test", save_dir=tmpdir) + logger.log_metrics({"test": 1, "epoch": 1}, step=123) + logger.experiment.log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123) \ No newline at end of file From 3b40c8d8133f553f24e28c446f0378b63ab7d5c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 5 Oct 2020 04:23:30 +0200 Subject: [PATCH 2/2] new line --- tests/loggers/test_comet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index 89b13c920d4d9..9478efc4960ed 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -186,4 +186,4 @@ def test_comet_epoch_logging(comet, comet_experiment, tmpdir, monkeypatch): _patch_comet_atexit(monkeypatch) logger = CometLogger(project_name="test", save_dir=tmpdir) logger.log_metrics({"test": 1, "epoch": 1}, step=123) - logger.experiment.log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123) \ No newline at end of file + logger.experiment.log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123)