diff --git a/CHANGES.md b/CHANGES.md
index 0965c63a6..30a7b3ee7 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Changed
 
+- `NeptuneLogger` was updated to work with recent versions of Neptune client (v0.14.3 or higher); it now logs some additional data, including the model summary, configuration, and learning rate (when available) (#906)
+
 ### Fixed
 
 ## [0.12.0] - 2022-10-07
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 6d73c2968..0943368fa 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -5,8 +5,7 @@ future>=0.17.1
 gpytorch>=1.5
 jupyter
 matplotlib>=2.0.2
-mlflow
-neptune-client>=0.4.103
+neptune-client>=0.14.3
 numpydoc
 openpyxl
 pandas
diff --git a/skorch/callbacks/logging.py b/skorch/callbacks/logging.py
index f12fc2f91..b2393eb44 100644
--- a/skorch/callbacks/logging.py
+++ b/skorch/callbacks/logging.py
@@ -65,86 +65,88 @@ def on_epoch_end(self, net, **kwargs):
 
 
 class NeptuneLogger(Callback):
-    """Logs results from history to Neptune
+    """Logs model metadata and training metrics to Neptune.
 
-    Neptune is a lightweight experiment tracking tool.
+    Neptune is a lightweight experiment-tracking tool.
     You can read more about it here: https://neptune.ai
 
     Use this callback to automatically log all interesting values from
     your net's history to Neptune.
 
     The best way to log additional information is to log directly to the
-    experiment object or subclass the ``on_*`` methods.
+    run object.
 
-    To monitor resource consumption install psutil
+    To monitor resource consumption, install psutil:
 
-    >>> python -m pip install psutil
+    $ python -m pip install psutil
 
     You can view example experiment logs here:
-    https://ui.neptune.ai/o/shared/org/skorch-integration/e/SKOR-13/charts
+    https://app.neptune.ai/shared/skorch-integration/e/SKOR-23/
 
     Examples
     --------
-    >>> # Install neptune
-    >>> python -m pip install neptune-client
-    >>> # Create a neptune experiment object
-    >>> import neptune
-    ...
-    ... # We are using api token for an anonymous user.
-    ... # For your projects use the token associated with your neptune.ai account
-    >>> neptune.init(api_token='ANONYMOUS',
-    ...              project_qualified_name='shared/skorch-integration')
-    ...
-    ... experiment = neptune.create_experiment(
-    ...     name='skorch-basic-example',
-    ...     params={'max_epochs': 20,
-    ...             'lr': 0.01},
-    ...     upload_source_files=['skorch_example.py'])
+    $ # Install Neptune
+    $ python -m pip install neptune-client
 
-    >>> # Create a neptune_logger callback
-    >>> neptune_logger = NeptuneLogger(experiment, close_after_train=False)
-
-    >>> # Pass a logger to net callbacks argument
+    >>> # Create a Neptune run
+    >>> import neptune.new as neptune
+    >>> from neptune.new.types import File
+    ...
+    ... # This example uses the API token for anonymous users.
+    ... # For your own projects, use the token associated with your neptune.ai account.
+    >>> run = neptune.init_run(
+    ...     api_token=neptune.ANONYMOUS_API_TOKEN,
+    ...     project='shared/skorch-integration',
+    ...     name='skorch-basic-example',
+    ...     source_files=['skorch_example.py'],
+    ... )
+
+    >>> # Create a NeptuneLogger callback
+    >>> neptune_logger = NeptuneLogger(run, close_after_train=False)
+
+    >>> # Pass the logger to the net callbacks argument
     >>> net = NeuralNetClassifier(
     ...           ClassifierModule,
     ...           max_epochs=20,
     ...           lr=0.01,
-    ...           callbacks=[neptune_logger])
+    ...           callbacks=[neptune_logger, Checkpoint(dirname="./checkpoints")])
+    >>> net.fit(X, y)
+
+    >>> # Save the checkpoints to Neptune
+    >>> neptune_logger.run["checkpoints"].upload_files("./checkpoints")
 
     >>> # Log additional metrics after training has finished
     >>> from sklearn.metrics import roc_auc_score
-    ... y_pred = net.predict_proba(X)
-    ... auc = roc_auc_score(y, y_pred[:, 1])
+    ... y_proba = net.predict_proba(X)
+    ... auc = roc_auc_score(y, y_proba[:, 1])
     ...
-    ... neptune_logger.experiment.log_metric('roc_auc_score', auc)
+    ... neptune_logger.run["roc_auc_score"].log(auc)
 
-    >>> # log charts like ROC curve
-    ... from scikitplot.metrics import plot_roc
-    ... import matplotlib.pyplot as plt
+    >>> # Log charts, such as an ROC curve
+    >>> from sklearn.metrics import RocCurveDisplay
     ...
-    ... fig, ax = plt.subplots(figsize=(16, 12))
-    ... plot_roc(y, y_pred, ax=ax)
-    ... neptune_logger.experiment.log_image('roc_curve', fig)
+    >>> roc_plot = RocCurveDisplay.from_estimator(net, X, y)
+    >>> neptune_logger.run["roc_curve"].upload(File.as_html(roc_plot.figure_))
 
-    >>> # log net object after training
+    >>> # Log the net object after training
     ... net.save_params(f_params='basic_model.pkl')
-    ... neptune_logger.experiment.log_artifact('basic_model.pkl')
+    ... neptune_logger.run["basic_model"].upload(File('basic_model.pkl'))
 
-    >>> # close experiment
-    ... neptune_logger.experiment.stop()
+    >>> # Close the run
+    ... neptune_logger.run.stop()
 
     Parameters
     ----------
-    experiment : neptune.experiments.Experiment
-      Instantiated ``Experiment`` class.
+    run : neptune.new.Run
+      Instantiated ``Run`` class.
 
     log_on_batch_end : bool (default=False)
       Whether to log loss and other metrics on batch level.
 
     close_after_train : bool (default=True)
-      Whether to close the ``Experiment`` object once training
+      Whether to close the ``Run`` object once training
       finishes. Set this parameter to False if you want to continue
-      logging to the same Experiment or if you use it as a context
+      logging to the same run or if you use it as a context
       manager.
 
     keys_ignored : str or list of str (default=None)
@@ -152,60 +154,117 @@ class NeptuneLogger(Callback):
       addition to the keys provided by the user, keys such as those
       starting with ``'event_'`` or ending on ``'_best'`` are
       ignored by default.
 
+    base_namespace : str (default='training')
+      Namespace (folder) under which all metadata logged by the
+      ``NeptuneLogger`` will be stored.
+
     Attributes
     ----------
-    first_batch_ : bool
-      Helper attribute that is set to True at initialization and changes
-      to False on first batch end. Can be used when we want to log things
-      exactly once.
-
     .. _Neptune: https://www.neptune.ai
 
     """
 
     def __init__(
             self,
-            experiment,
+            run,
+            *,
             log_on_batch_end=False,
             close_after_train=True,
             keys_ignored=None,
+            base_namespace='training',
    ):
-        self.experiment = experiment
+        self.run = run
         self.log_on_batch_end = log_on_batch_end
         self.close_after_train = close_after_train
         self.keys_ignored = keys_ignored
+        self.base_namespace = base_namespace
 
-    def initialize(self):
-        self.first_batch_ = True
+    @property
+    def _metric_logger(self):
+        return self.run[self._base_namespace]
 
+    @staticmethod
+    def _get_obj_name(obj):
+        return type(obj).__name__
+
+    def initialize(self):
         keys_ignored = self.keys_ignored
         if isinstance(keys_ignored, str):
             keys_ignored = [keys_ignored]
         self.keys_ignored_ = set(keys_ignored or [])
         self.keys_ignored_.add('batches')
+
+        if self.base_namespace.endswith("/"):
+            self._base_namespace = self.base_namespace[:-1]
+        else:
+            self._base_namespace = self.base_namespace
+
         return self
 
+    def on_train_begin(self, net, X, y, **kwargs):
+        # TODO: we might want to improve logging of the multi-module net objects, see:
+        # https://github.com/skorch-dev/skorch/pull/906#discussion_r993514643
+
+        self._metric_logger['model/model_type'] = self._get_obj_name(net.module_)
+        self._metric_logger['model/summary'] = self._model_summary_file(net.module_)
+
+        self._metric_logger['config/optimizer'] = self._get_obj_name(net.optimizer_)
+        self._metric_logger['config/criterion'] = self._get_obj_name(net.criterion_)
+        self._metric_logger['config/lr'] = net.lr
+        self._metric_logger['config/epochs'] = net.max_epochs
+        self._metric_logger['config/batch_size'] = net.batch_size
+        self._metric_logger['config/device'] = net.device
+
     def on_batch_end(self, net, **kwargs):
         if self.log_on_batch_end:
             batch_logs = net.history[-1]['batches'][-1]
 
             for key in filter_log_keys(batch_logs.keys(), self.keys_ignored_):
-                self.experiment.log_metric(key, batch_logs[key])
-
-        self.first_batch_ = False
+                self._log_metric(key, batch_logs, batch=True)
 
     def on_epoch_end(self, net, **kwargs):
         """Automatically log values from the last history step."""
-        history = net.history
-        epoch_logs = history[-1]
-        epoch = epoch_logs['epoch']
+        epoch_logs = net.history[-1]
 
         for key in filter_log_keys(epoch_logs.keys(), self.keys_ignored_):
-            self.experiment.log_metric(key, x=epoch, y=epoch_logs[key])
+            self._log_metric(key, epoch_logs, batch=False)
 
     def on_train_end(self, net, **kwargs):
+        try:
+            self._metric_logger['train/epoch/event_lr'].log(net.history[:, 'event_lr'])
+        except KeyError:
+            pass
         if self.close_after_train:
-            self.experiment.stop()
+            self.run.stop()
+
+    def _log_metric(self, name, logs, batch):
+        kind, _, key = name.partition('_')
+
+        if not key:
+            key = 'epoch_duration' if kind == 'dur' else kind
+            self._metric_logger[key].log(logs[name])
+        else:
+            if kind == 'valid':
+                kind = 'validation'
+
+            if batch:
+                granularity = 'batch'
+            else:
+                granularity = 'epoch'
+
+            # for example: train / epoch / loss
+            self._metric_logger[kind][granularity][key].log(logs[name])
+
+    @staticmethod
+    def _model_summary_file(model):
+        try:
+            # neptune-client>=0.9.0,<1.0.0 package structure
+            from neptune.new.types import File
+        except ImportError:
+            # neptune-client>=1.0.0 package structure
+            from neptune.types import File
+
+        return File.from_content(str(model), extension='txt')
 
 
 class WandbLogger(Callback):
diff --git a/skorch/tests/callbacks/test_logging.py b/skorch/tests/callbacks/test_logging.py
index 83a6ce5d4..13532472d 100644
--- a/skorch/tests/callbacks/test_logging.py
+++ b/skorch/tests/callbacks/test_logging.py
@@ -2,6 +2,8 @@
 
 from functools import partial
 import os
+import tempfile
+import unittest.mock
 from unittest.mock import Mock
 from unittest.mock import call, patch
 
@@ -16,10 +18,12 @@ from skorch.tests.conftest import tensorboard_installed
 from skorch.tests.conftest import mlflow_installed
 
-
 @pytest.mark.skipif(
     not neptune_installed, reason='neptune is not installed')
 class TestNeptune:
+    # fields logged by on_train_begin and on_train_end
+    NUM_BASE_METRICS = 9
+
     @pytest.fixture
     def net_cls(self):
         from skorch import NeuralNetClassifier
@@ -38,18 +42,24 @@ def neptune_logger_cls(self):
         return NeptuneLogger
 
     @pytest.fixture
-    def neptune_experiment_cls(self):
-        import neptune
-        neptune.init(project_qualified_name="tests/dry-run",
-                     backend=neptune.OfflineBackend())
-        return neptune.create_experiment
+    def neptune_run_object(self):
+        try:
+            # neptune-client>=0.9.0,<1.0.0 package structure
+            import neptune.new as neptune
+        except ImportError:
+            # neptune-client>=1.0.0 package structure
+            import neptune
+
+        run = neptune.init_run(
+            project="tests/dry-run",
+            mode="offline",
+        )
+        return run
 
     @pytest.fixture
-    def mock_experiment(self, neptune_experiment_cls):
-        mock = Mock(spec=neptune_experiment_cls)
-        mock.log_metric = Mock()
-        mock.stop = Mock()
-        return mock
+    def mock_experiment(self, neptune_run_object):
+        with neptune_run_object as run:
+            return unittest.mock.create_autospec(run)
 
     @pytest.fixture
     def net_fitted(
@@ -69,6 +79,10 @@ def net_fitted(
     def test_experiment_closed_automatically(self, net_fitted, mock_experiment):
         assert mock_experiment.stop.call_count == 1
 
+    def test_experiment_log_call_counts(self, net_fitted, mock_experiment):
+        # (3 x dur + 3 x train_loss + 3 x valid_loss + 3 x valid_acc = 12) + base metrics
+        assert mock_experiment.__getitem__.call_count == 12 + self.NUM_BASE_METRICS
+
     def test_experiment_not_closed(
             self,
             net_cls,
@@ -103,10 +117,8 @@ def test_ignore_keys(
             max_epochs=3,
         ).fit(*data)
 
-        # 3 epochs x 2 epoch metrics = 6 calls
-        assert mock_experiment.log_metric.call_count == 6
-        call_args = [args[0][0] for args in mock_experiment.log_metric.call_args_list]
-        assert 'valid_loss' not in call_args
+        # (3 epochs x 2 epoch metrics = 6 calls) + base metrics
+        assert mock_experiment.__getitem__.call_count == 6 + self.NUM_BASE_METRICS
 
     def test_keys_ignored_is_string(self, neptune_logger_cls, mock_experiment):
         npt = neptune_logger_cls(
@@ -120,15 +132,23 @@ def test_fit_with_real_experiment(
             classifier_module,
             data,
             neptune_logger_cls,
-            neptune_experiment_cls,
+            neptune_run_object,
     ):
         net = net_cls(
             classifier_module,
-            callbacks=[neptune_logger_cls(neptune_experiment_cls())],
+            callbacks=[neptune_logger_cls(neptune_run_object)],
             max_epochs=5,
         )
         net.fit(*data)
 
+        assert neptune_run_object.exists('training/epoch_duration')
+        assert neptune_run_object.exists('training/train/epoch/loss')
+        assert neptune_run_object.exists('training/validation/epoch/loss')
+        assert neptune_run_object.exists('training/validation/epoch/acc')
+
+        # Checkpoint callback was not used
+        assert not neptune_run_object.exists('training/model/checkpoint')
+
     def test_log_on_batch_level_on(
             self,
             net_cls,
@@ -146,9 +166,9 @@ def test_log_on_batch_level_on(
         )
         net.fit(*data)
 
-        # 5 epochs x (40/4 batches x 2 batch metrics + 2 epoch metrics) = 110 calls
-        assert mock_experiment.log_metric.call_count == 110
-        mock_experiment.log_metric.assert_any_call('train_batch_size', 4)
+        # (5 epochs x (40/4 batches x 2 batch metrics + 2 epoch metrics) = 110 calls) + base metrics
+        assert mock_experiment.__getitem__.call_count == 110 + self.NUM_BASE_METRICS
+        mock_experiment['training']['train']['batch']['batch_size'].log.assert_any_call(4)
 
     def test_log_on_batch_level_off(
             self,
@@ -167,31 +187,59 @@ def test_log_on_batch_level_off(
         )
         net.fit(*data)
 
-        # 5 epochs x 2 epoch metrics = 10 calls
-        assert mock_experiment.log_metric.call_count == 10
-        call_args_list = mock_experiment.log_metric.call_args_list
-        assert call('train_batch_size', 4) not in call_args_list
+        # (5 epochs x 2 epoch metrics = 10 calls) + base metrics
+        assert mock_experiment.__getitem__.call_count == 10 + self.NUM_BASE_METRICS
 
-    def test_first_batch_flag(
+        call_args = mock_experiment['training']['train'].__getitem__.call_args_list
+        assert call('epoch') in call_args
+        assert call('batch') not in call_args
+
+        call_args = mock_experiment['training']['validation'].__getitem__.call_args_list
+        assert call('epoch') in call_args
+        assert call('batch') not in call_args
+
+    def test_fit_with_real_experiment_saving_checkpoints(
             self,
             net_cls,
             classifier_module,
             data,
             neptune_logger_cls,
-            neptune_experiment_cls,
+            neptune_run_object,
     ):
-        npt = neptune_logger_cls(neptune_experiment_cls())
-        npt.initialize()
-        assert npt.first_batch_ is True
+        try:
+            # neptune-client>=0.9.0,<1.0.0 package structure
+            from neptune.new.attributes.file_set import FileSet
+        except ImportError:
+            # neptune-client>=1.0.0 package structure
+            from neptune.attributes.file_set import FileSet
+        from skorch.callbacks import Checkpoint
+
+        with tempfile.TemporaryDirectory() as directory:
+            net = net_cls(
+                classifier_module,
+                callbacks=[
+                    neptune_logger_cls(
+                        run=neptune_run_object,
+                        close_after_train=False,
+                    ),
+                    Checkpoint(dirname=directory),
+                ],
+                max_epochs=5,
+            )
+            net.fit(*data)
+            neptune_run_object['training/model/checkpoint'].upload_files(directory)
 
-        net = net_cls(
-            classifier_module,
-            callbacks=[npt],
-            max_epochs=1,
+        assert neptune_run_object.exists('training/train/epoch/loss')
+        assert neptune_run_object.exists('training/validation/epoch/loss')
+        assert neptune_run_object.exists('training/validation/epoch/acc')
+
+        assert neptune_run_object.exists('training/model/checkpoint')
+        assert isinstance(
+            neptune_run_object.get_structure()['training']['model']['checkpoint'],
+            FileSet,
         )
-        npt.on_batch_end(net)
-        assert npt.first_batch_ is False
+        neptune_run_object.stop()
 
 
 @pytest.mark.skipif(
     not sacred_installed, reason='Sacred is not installed')