diff --git a/CHANGES.md b/CHANGES.md index 5bb3b40b0..c195dc29f 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -14,8 +14,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Initialize data loaders for training and validation dataset once per fit call instead of once per epoch ([migration guide](https://skorch.readthedocs.io/en/stable/user/FAQ.html#migration-from-0-11-to-0-12)) +- It is now possible to call `np.asarray` with `SliceDataset`s (#858) ### Fixed +- Fix a bug in `SliceDataset` that prevented it to be used with `to_numpy` (#858) ## [0.11.0] - 2021-10-11 diff --git a/skorch/helper.py b/skorch/helper.py index 872c6d9d4..9f92c3ba8 100644 --- a/skorch/helper.py +++ b/skorch/helper.py @@ -14,6 +14,7 @@ from skorch.cli import parse_args # pylint: disable=unused-import from skorch.dataset import unpack_data from skorch.utils import _make_split +from skorch.utils import to_numpy from skorch.utils import is_torch_data_type from skorch.utils import to_tensor @@ -246,6 +247,14 @@ def __getitem__(self, i): return SliceDataset(self.dataset, idx=self.idx, indices=self.indices_[i]) + def __array__(self, dtype=None): + # This method is invoked when calling np.asarray(X) + # https://numpy.org/devdocs/user/basics.dispatch.html + X = [self[i] for i in range(len(self))] + if np.isscalar(X[0]): + return np.asarray(X) + return np.asarray([to_numpy(x) for x in X], dtype=dtype) + def predefined_split(dataset): """Uses ``dataset`` for validiation in :class:`.NeuralNet`. diff --git a/skorch/tests/callbacks/test_all.py b/skorch/tests/callbacks/test_all.py index 856fc2a36..2b946fc03 100644 --- a/skorch/tests/callbacks/test_all.py +++ b/skorch/tests/callbacks/test_all.py @@ -46,12 +46,14 @@ def test_on_x_methods_have_kwargs(self, callbacks, on_x_methods): def test_set_params_with_unknown_key_raises(self, base_cls): with pytest.raises(ValueError) as exc: base_cls().set_params(foo=123) - - msg_start = ( - "Invalid parameter foo for estimator