Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update for neptune 1.0 #934

Merged
merged 9 commits into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ future>=0.17.1
gpytorch>=1.5
jupyter
matplotlib>=2.0.2
neptune-client>=0.14.3
neptune
numpydoc
openpyxl
pandas
Expand Down
29 changes: 19 additions & 10 deletions skorch/callbacks/logging.py
kshitij12345 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,16 @@ class NeptuneLogger(Callback):
$ python -m pip install psutil

You can view example experiment logs here:
https://app.neptune.ai/shared/skorch-integration/e/SKOR-23/
https://app.neptune.ai/o/common/org/skorch-integration/e/SKOR-32/all

Examples
--------
$ # Install Neptune
$ python -m pip install neptune-client
$ python -m pip install neptune

>>> # Create a Neptune run
>>> import neptune.new as neptune
>>> from neptune.new.types import File
>>> import neptune
>>> from neptune.types import File
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked the example and found an issue which was already there before but should still be fixed:

- neptune_logger.run["checkpoints].upload_files("./checkpoints")
+ neptune_logger.run["checkpoints"].upload_files("./checkpoints") 

Also, a few lines that start with ... should be >>>.

...
... # This example uses the API token for anonymous users.
... # For your own projects, use the token associated with your neptune.ai account.
Expand Down Expand Up @@ -137,8 +137,8 @@ class NeptuneLogger(Callback):

Parameters
----------
run : neptune.new.Run
Instantiated ``Run`` class.
run : neptune.Run or neptune.handler.Handler
Instantiated ``Run`` or ``Handler`` class.

log_on_batch_end : bool (default=False)
Whether to log loss and other metrics on batch level.
Expand Down Expand Up @@ -231,18 +231,27 @@ def on_epoch_end(self, net, **kwargs):

def on_train_end(self, net, **kwargs):
try:
self._metric_logger['train/epoch/event_lr'].log(net.history[:, 'event_lr'])
self._metric_logger['train/epoch/event_lr'].append(net.history[:, 'event_lr'])
except KeyError:
pass
if self.close_after_train:
self.run.stop()
try: # >1.0 package structure
from neptune.handler import Handler
except ImportError: # <1.0 package structure
from neptune.new.handler import Handler
Comment on lines +235 to +238
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be just

from neptune.handler import Handler

given that the dependency is v >=1.0.0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actual logging code which should be compatible with both pre and post 1.0.

Other update where we assume 1.0 structure is in test (where we have updated the requirements to install latest version).

Copy link
Contributor

@twolodzko twolodzko Mar 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But the requirements say neptune hence >=1.0.0, so there is no backward compatibility assumed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is the requirements for development. Installing skorch as a user doesn't install neptune. So we could have users who may use older neptune-client.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is all the code compatible with older versions? E.g. the use of append instead of log. If not, there is no point in having this check.

I think it's fine to tell users to install the latest version of neptune (except if you know that users are reluctant to upgrade versions). E.g. when this class is initialized, there could be a version check and a helpful error message when the version is too low. WDYT?


root_obj = self.run
if isinstance(self.run, Handler):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add a comment here? What is the Handler and why do we sometimes get a Handler and sometimes a Run?

root_obj = self.run.get_root_object()

root_obj.stop()

def _log_metric(self, name, logs, batch):
kind, _, key = name.partition('_')

if not key:
key = 'epoch_duration' if kind == 'dur' else kind
self._metric_logger[key].log(logs[name])
self._metric_logger[key].append(logs[name])
else:
if kind == 'valid':
kind = 'validation'
Expand All @@ -253,7 +262,7 @@ def _log_metric(self, name, logs, batch):
granularity = 'epoch'

# for example: train / epoch / loss
self._metric_logger[kind][granularity][key].log(logs[name])
self._metric_logger[kind][granularity][key].append(logs[name])

@staticmethod
def _model_summary_file(model):
Expand Down
39 changes: 26 additions & 13 deletions skorch/tests/callbacks/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,7 @@ def neptune_logger_cls(self):

@pytest.fixture
def neptune_run_object(self):
try:
# neptune-client=0.9.0+ package structure
import neptune.new as neptune
except ImportError:
# neptune-client>=1.0.0 package structure
import neptune
import neptune
BenjaminBossan marked this conversation as resolved.
Show resolved Hide resolved

run = neptune.init_run(
project="tests/dry-run",
Expand Down Expand Up @@ -149,6 +144,29 @@ def test_fit_with_real_experiment(
# Checkpoint callback was not used
assert not neptune_run_object.exists('training/model/checkpoint')

def test_fit_with_handler(
self,
net_cls,
classifier_module,
data,
neptune_logger_cls,
neptune_run_object,
):
net = net_cls(
classifier_module,
callbacks=[neptune_logger_cls(neptune_run_object['my_namespace'])],
max_epochs=5,
)
net.fit(*data)

assert neptune_run_object.exists('my_namespace/training/epoch_duration')
assert neptune_run_object.exists('my_namespace/training/train/epoch/loss')
assert neptune_run_object.exists('my_namespace/training/validation/epoch/loss')
assert neptune_run_object.exists('my_namespace/training/validation/epoch/acc')

# Checkpoint callback was not used
assert not neptune_run_object.exists('my_namespace/training/model/checkpoint')

def test_log_on_batch_level_on(
self,
net_cls,
Expand All @@ -168,7 +186,7 @@ def test_log_on_batch_level_on(

# (5 epochs x (40/4 batches x 2 batch metrics + 2 epoch metrics) = 110 calls) + base metrics
assert mock_experiment.__getitem__.call_count == 110 + self.NUM_BASE_METRICS
mock_experiment['training']['train']['batch']['batch_size'].log.assert_any_call(4)
mock_experiment['training']['train']['batch']['batch_size'].append.assert_any_call(4)

def test_log_on_batch_level_off(
self,
Expand Down Expand Up @@ -206,12 +224,7 @@ def test_fit_with_real_experiment_saving_checkpoints(
neptune_logger_cls,
neptune_run_object,
):
try:
# neptune-client=0.9.0+ package structure
from neptune.new.attributes.file_set import FileSet
except ImportError:
# neptune-client>=1.0.0 package structure
from neptune.attributes.file_set import FileSet
from neptune.attributes.file_set import FileSet
from skorch.callbacks import Checkpoint

with tempfile.TemporaryDirectory() as directory:
Expand Down