diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml
index b93ef14d..2bd0730c 100644
--- a/.github/workflows/testing.yml
+++ b/.github/workflows/testing.yml
@@ -20,8 +20,8 @@ jobs:
strategy:
fail-fast: false # don't cancel all jobs when one fails
matrix:
- python_version: ['3.8', '3.9', '3.10', '3.11']
- torch_version: ['2.1.2+cpu', '2.2.2+cpu', '2.3.1+cpu', '2.4.0+cpu']
+ python_version: ['3.9', '3.10', '3.11', '3.12']
+ torch_version: ['2.2.2+cpu', '2.3.1+cpu', '2.4.1+cpu', '2.5.1+cpu']
os: [ubuntu-latest]
steps:
diff --git a/README.rst b/README.rst
index a68e0f0a..e2af81ab 100644
--- a/README.rst
+++ b/README.rst
@@ -135,7 +135,7 @@ skorch also provides many convenient features, among others:
Installation
============
-skorch requires Python 3.8 or higher.
+skorch requires Python 3.9 or higher.
conda installation
==================
@@ -244,10 +244,10 @@ instructions for PyTorch, visit the `PyTorch website
`__. skorch officially supports the last four
minor PyTorch versions, which currently are:
-- 2.1.2
- 2.2.2
- 2.3.1
-- 2.4.0
+- 2.4.1
+- 2.5.1
However, that doesn't mean that older versions don't work, just that
they aren't tested. Since skorch mostly relies on the stable part of
diff --git a/docs/user/installation.rst b/docs/user/installation.rst
index 6a369339..5f424048 100644
--- a/docs/user/installation.rst
+++ b/docs/user/installation.rst
@@ -58,7 +58,7 @@ If you want to help developing, run:
pylint skorch # static code checks
You may adjust the Python version to any of the supported Python versions, i.e.
-Python 3.8 or higher.
+Python 3.9 or higher.
Using pip
^^^^^^^^^
@@ -98,10 +98,10 @@ instructions for PyTorch, visit the `PyTorch website
`__. skorch officially supports the last four
minor PyTorch versions, which currently are:
-- 2.1.2
- 2.2.2
- 2.3.1
-- 2.4.0
+- 2.4.1
+- 2.5.1
However, that doesn't mean that older versions don't work, just that
they aren't tested. Since skorch mostly relies on the stable part of
diff --git a/requirements-dev.txt b/requirements-dev.txt
index f02124df..4202abcf 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -10,7 +10,7 @@ numpydoc
openpyxl
pandas
pillow
-protobuf>=3.12.0,<4.0dev
+protobuf
pylint
pytest>=3.4
pytest-cov
diff --git a/setup.py b/setup.py
index 2de0f3b8..f4e40f40 100644
--- a/setup.py
+++ b/setup.py
@@ -13,7 +13,7 @@
tests_require = [l.strip() for l in f]
-python_requires = '>=3.8'
+python_requires = '>=3.9'
docs_require = [
'Sphinx',
diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py
index 768ef00b..53c07ab0 100644
--- a/skorch/tests/test_net.py
+++ b/skorch/tests/test_net.py
@@ -484,7 +484,7 @@ def test_pickle_save_and_load_mixed_devices(
with open(str(p), 'rb') as f:
if not expect_warning:
m = pickle.load(f)
- assert not recwarn.list
+ assert not any(w.category == DeviceWarning for w in recwarn.list)
else:
with pytest.warns(DeviceWarning) as w:
m = pickle.load(f)
@@ -495,11 +495,17 @@ def test_pickle_save_and_load_mixed_devices(
# We should have captured two warnings:
# 1. one for the failed load
# 2. for switching devices on the net instance
- assert len(w.list) == 2
- assert w.list[0].message.args[0] == (
+ # remove possible future warning about weights_only=False
+ # TODO: remove filter when torch<=2.4 is dropped
+ w_list = [
+ warning for warning in w.list
+ if "weights_only=False" not in warning.message.args[0]
+ ]
+ assert len(w_list) == 2
+ assert w_list[0].message.args[0] == (
'Requested to load data to CUDA but no CUDA devices '
'are available. Loading on device "cpu" instead.')
- assert w.list[1].message.args[0] == (
+ assert w_list[1].message.args[0] == (
'Setting self.device = {} since the requested device ({}) '
'is not available.'.format(load_dev, save_dev))
@@ -4254,6 +4260,13 @@ def test_fit_and_predict_with_compile(self, net_cls, module_cls, data):
if not hasattr(torch, 'compile'):
pytest.skip(reason="torch.compile not available")
+ # python 3.12 requires torch >= 2.4 to support compile
+ # TODO: remove once we remove support for torch < 2.4
+ from skorch._version import Version
+
+ if Version(torch.__version__) < Version('2.4.0') and sys.version_info >= (3, 12):
+ pytest.skip(reason="When using Python 3.12, torch.compile requires torch >= 2.4")
+
# use real torch.compile, not mocked, can be a bit slow
X, y = data
net = net_cls(module_cls, max_epochs=1, compile=True).initialize()
@@ -4274,6 +4287,13 @@ def test_binary_classifier_with_compile(self, data):
# because of a failing isinstance check
from skorch import NeuralNetBinaryClassifier
+ # python 3.12 requires torch >= 2.4 to support compile
+ # TODO: remove once we remove support for torch < 2.4
+ from skorch._version import Version
+
+ if Version(torch.__version__) < Version('2.4.0') and sys.version_info >= (3, 12):
+ pytest.skip(reason="When using Python 3.12, torch.compile requires torch >= 2.4")
+
X, y = data[0], data[1].astype(np.float32)
class MyNet(nn.Module):
diff --git a/skorch/tests/test_regressor.py b/skorch/tests/test_regressor.py
index 6d414543..2c4da7ea 100644
--- a/skorch/tests/test_regressor.py
+++ b/skorch/tests/test_regressor.py
@@ -134,7 +134,6 @@ def test_dimension_mismatch_warning(self, net_cls, module_cls, data, recwarn):
X, y = X[:100], y[:100].flatten() # make y 1d
net.fit(X, y)
- w0, w1 = recwarn.list # one warning for train, one for valid
# The warning comes from PyTorch, so checking the exact wording is prone to
# error in future PyTorch versions. We thus check a substring of the
# whole message and cross our fingers that it's not changed.
@@ -142,8 +141,9 @@ def test_dimension_mismatch_warning(self, net_cls, module_cls, data, recwarn):
"This will likely lead to incorrect results due to broadcasting. "
"Please ensure they have the same size"
)
- assert msg_substr in str(w0.message)
- assert msg_substr in str(w1.message)
+ warn_list = [w for w in recwarn.list if msg_substr in str(w.message)]
+ # one warning for train, one for valid
+ assert len(warn_list) == 2
def test_fitting_with_1d_target_and_pred(
self, net_cls, module_cls, data, module_pred_1d_cls, recwarn
@@ -159,7 +159,11 @@ def test_fitting_with_1d_target_and_pred(
net = net_cls(module_pred_1d_cls)
net.fit(X, y)
- assert not recwarn.list
+ msg_substr = (
+ "This will likely lead to incorrect results due to broadcasting. "
+ "Please ensure they have the same size"
+ )
+ assert not any(msg_substr in str(w.message) for w in recwarn.list)
def test_bagging_regressor(
self, net_cls, module_cls, data, module_pred_1d_cls, recwarn
@@ -173,4 +177,9 @@ def test_bagging_regressor(
y = y.flatten() # make y 1d or else sklearn will complain
regr = BaggingRegressor(net, n_estimators=2, random_state=0)
regr.fit(X, y) # does not raise
- assert not recwarn.list # ensure there is no broadcast warning from torch
+ # ensure there is no broadcast warning from torch
+ msg_substr = (
+ "This will likely lead to incorrect results due to broadcasting. "
+ "Please ensure they have the same size"
+ )
+ assert not any(msg_substr in str(w.message) for w in recwarn.list)