diff --git a/CHANGES.md b/CHANGES.md index b705165ce..0965c63a6 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -7,12 +7,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +### Changed + +### Fixed + +## [0.12.0] - 2022-10-07 + ### Added - Added `load_best` attribute to `EarlyStopping` callback to automatically load module weights of the best result at the end of training - Added a method, `trim_for_prediction`, on the net classes, which trims the net from everything not required for using it for prediction; call this after fitting to reduce the size of the net - Added experimental support for [huggingface accelerate](https://github.com/huggingface/accelerate); use the provided mixin class to add advanced training capabilities provided by the accelerate library to skorch - Add integration for Huggingface tokenizers; use `skorch.hf.HuggingfaceTokenizer` to train a Huggingface tokenizer on your custom data; use `skorch.hf.HuggingfacePretrainedTokenizer` to load a pre-trained Huggingface tokenizer - Added support for creating model checkpoints on Hugging Face Hub using [`HfHubStorage`](https://skorch.readthedocs.io/en/latest/hf.html#skorch.hf.HfHubStorage) +- Added a [notebook](https://nbviewer.org/github/skorch-dev/skorch/blob/master/notebooks/CORA-geometric.ipynb) that shows how to use skorch with PyTorch Geometric (#863) ### Changed - The minimum required scikit-learn version has been bumped to 0.22.0 @@ -23,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix a bug in `SliceDataset` that prevented it to be used with `to_numpy` (#858) - Fix a bug that occurred when loading a net that has device set to None (#876) - Fix a bug that in some cases could prevent loading a net that was trained with CUDA without CUDA +- Enable skorch to work on M1/M2 Apple MacBooks (#884) ## [0.11.0] - 2021-10-11 @@ -273,3 +283,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [0.9.0]: https://github.com/skorch-dev/skorch/compare/v0.8.0...v0.9.0 [0.10.0]: https://github.com/skorch-dev/skorch/compare/v0.9.0...v0.10.0 [0.11.0]: https://github.com/skorch-dev/skorch/compare/v0.10.0...v0.11.0 +[0.12.0]: https://github.com/skorch-dev/skorch/compare/v0.11.0...v0.12.0 diff --git a/VERSION b/VERSION index 56a585d4a..ac454c6a1 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.11.1dev +0.12.0 diff --git a/skorch/callbacks/base.py b/skorch/callbacks/base.py index da0ac180d..5462ec5d7 100644 --- a/skorch/callbacks/base.py +++ b/skorch/callbacks/base.py @@ -63,84 +63,3 @@ def get_params(self, deep=True): def set_params(self, **params): BaseEstimator.set_params(self, **params) - - -# TODO: remove after some deprecation period, e.g. skorch 0.12 -def _on_batch_overridden(callback): - """Check if on_batch_begin or on_batch_end were overridden - - If the method does not exist at all, it's not considered overridden. This is - mostly for callbacks that are mocked. - - """ - try: - base_skorch_cls = next(cls for cls in callback.__class__.__mro__ - if cls.__module__.startswith('skorch')) - except StopIteration: - # does not derive from skorch callback, possibly a mock - return False - - obb = base_skorch_cls.on_batch_begin - obe = base_skorch_cls.on_batch_end - return ( - getattr(callback.__class__, 'on_batch_begin', obb) is not obb - or getattr(callback.__class__, 'on_batch_end', obe) is not obe - ) - - -# TODO: remove after some deprecation period, e.g. skorch 0.12 -def _issue_warning_if_on_batch_override(callback_list): - """Check callbacks for overridden on_batch method and issue warning - - We introduced a breaking change by changing the signature of on_batch_begin - and on_batch_end. To help users, we try to detect if they use any custom - callback that overrides on of these methods and issue a warning if they do. - The warning states how to adjust the method signature and how it can be - filtered. - - After some transition period, the checking and the warning should be - removed again. - - Parameters - ---------- - callback_list : list of (str, callback) tuples - List of initialized callbacks. - - Warns - ----- - Issues a ``SkorchWarning`` if any of the callbacks fits the conditions. - - """ - if not callback_list: - return - - callbacks = [callback for _, callback in callback_list] - - # first detect if there are any user defined callbacks - user_defined_callbacks = [ - callback for callback in callbacks - if not callback.__module__.startswith('skorch') - ] - if not user_defined_callbacks: - return - - # check if any of these callbacks overrides on_batch_begin or on_batch_end - overriding_callbacks = [ - callback for callback in user_defined_callbacks - if _on_batch_overridden(callback) - ] - - if not overriding_callbacks: - return - - warning_msg = ( - "You are using an callback that overrides on_batch_begin " - "or on_batch_end. As of skorch 0.10, the signature was changed " - "from 'on_batch_{begin,end}(self, X, y, ...)' to " - "'on_batch_{begin,end}(self, batch, ...)'. To recover, change " - "the signature accordingly and add 'X, y = batch' on the first " - "line of the method body. To suppress this warning, add:\n" - "'import warnings; from skorch.exceptions import SkorchWarning\n" - "warnings.filterwarnings('ignore', category=SkorchWarning)'.") - - warnings.warn(warning_msg, SkorchWarning) diff --git a/skorch/net.py b/skorch/net.py index 1157314a0..def9891a3 100644 --- a/skorch/net.py +++ b/skorch/net.py @@ -23,7 +23,6 @@ from skorch.callbacks import EpochTimer from skorch.callbacks import PrintLog from skorch.callbacks import PassthroughScoring -from skorch.callbacks.base import _issue_warning_if_on_batch_override from skorch.dataset import Dataset from skorch.dataset import ValidSplit from skorch.dataset import get_len @@ -352,10 +351,6 @@ def notify(self, method_name, **cb_kwargs): * on_batch_end """ - # TODO: remove after some deprecation period, e.g. skorch 0.12 - if not self.history: # perform check only at the start - _issue_warning_if_on_batch_override(self.callbacks_) - getattr(self, method_name)(self, **cb_kwargs) for _, cb in self.callbacks_: getattr(cb, method_name)(self, **cb_kwargs) diff --git a/skorch/tests/callbacks/test_base.py b/skorch/tests/callbacks/test_base.py deleted file mode 100644 index 5c1085807..000000000 --- a/skorch/tests/callbacks/test_base.py +++ /dev/null @@ -1,115 +0,0 @@ -"""Tests for callbacks/base.py""" - -import warnings - -import pytest - - -# TODO: remove after some deprecation period, e.g. skorch 0.12 -class TestIssueWarningIfOnBatchOverride: - @pytest.fixture - def net(self, classifier_module): - from skorch import NeuralNetClassifier - return NeuralNetClassifier(classifier_module, max_epochs=1) - - @pytest.fixture(scope='module') - def data(self, classifier_data): - return classifier_data - - @pytest.fixture(scope='module') - def callback_cls(self): - from skorch.callbacks import Callback - return Callback - - @pytest.fixture(scope='module') - def skorch_warning(self): - from skorch.exceptions import SkorchWarning - return SkorchWarning - - def test_no_warning_with_default_callbacks(self, net, data, recwarn): - from skorch.callbacks import EpochScoring - net.set_params(callbacks=[('f1', EpochScoring('f1'))]) - net.fit(*data) - - assert not recwarn.list - - def test_no_warning_if_on_batch_not_overridden( - self, net, data, callback_cls, recwarn): - class MyCallback(callback_cls): - def on_epoch_end(self, *args, **kwargs): - pass - - net.set_params(callbacks=[('cb', MyCallback())]) - net.fit(*data) - - assert not recwarn.list - - def test_warning_if_on_batch_begin_overridden( - self, net, data, callback_cls, skorch_warning): - class MyCallback(callback_cls): - def on_batch_begin(self, *args, **kwargs): - pass - - net.set_params(callbacks=[('cb', MyCallback())]) - with pytest.warns(skorch_warning): - net.fit(*data) - - def test_warning_if_on_batch_end_overridden( - self, net, data, callback_cls, skorch_warning): - class MyCallback(callback_cls): - def on_batch_end(self, *args, **kwargs): - pass - - net.set_params(callbacks=[('cb', MyCallback())]) - with pytest.warns(skorch_warning): - net.fit(*data) - - def test_warning_if_on_batch_begin_and_end_overridden( - self, net, data, callback_cls, skorch_warning): - class MyCallback(callback_cls): - def on_batch_begin(self, *args, **kwargs): - pass - def on_batch_end(self, *args, **kwargs): - pass - - net.set_params(callbacks=[('cb', MyCallback())]) - with pytest.warns(skorch_warning): - net.fit(*data) - - def test_no_warning_if_not_derived_from_base_and_no_override( - self, net, data, recwarn): - from skorch.callbacks import EpochScoring - - class MyCallback(EpochScoring): - pass - - net.set_params(callbacks=[('f1', MyCallback('f1'))]) - net.fit(*data) - - assert not recwarn.list - - def test_warning_if_not_derived_from_base_and_override( - self, net, data, skorch_warning): - from skorch.callbacks import EpochScoring - - class MyCallback(EpochScoring): - def on_batch_begin(self, *args, **kwargs): - pass - - net.set_params(callbacks=[('f1', MyCallback('f1'))]) - - with pytest.warns(skorch_warning): - net.fit(*data) - - def test_no_warning_if_filtered( - self, net, data, callback_cls, skorch_warning, recwarn): - warnings.filterwarnings('ignore', category=skorch_warning) - - class MyCallback(callback_cls): - def on_batch_begin(self, *args, **kwargs): - pass - - net.set_params(callbacks=[('cb', MyCallback())]) - net.fit(*data) - - assert not recwarn.list