Skip to content

Commit

Permalink
CI Add PyTorch 2.5.1 and Python 3.12 (#1069)
Browse files Browse the repository at this point in the history
Changes:

- Drop 2.1.2, update 2.4.0 to 2.4.1
- Drop Python 3.8, add Python 3.12
- Filter new torch warning about weights_only as that trips up some
  warning tests
- Unpin protobuf from dev requirements
- Filter warnings caused by usage of utcnow by protobuf
- Skip torch.compile tests when torch < 2.4 and Python 3.12 are used, as
  this is not supported by torch

Note that I had to upgrade torch and Python at the same time because
torch 2.5 requires Python > 3.8 and because Python 3.12 requires torch >
2.1 (which is dropped in favor of 2.5).
  • Loading branch information
BenjaminBossan authored Nov 5, 2024
1 parent 9ff9cfa commit ad0259b
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 19 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ jobs:
strategy:
fail-fast: false # don't cancel all jobs when one fails
matrix:
python_version: ['3.8', '3.9', '3.10', '3.11']
torch_version: ['2.1.2+cpu', '2.2.2+cpu', '2.3.1+cpu', '2.4.0+cpu']
python_version: ['3.9', '3.10', '3.11', '3.12']
torch_version: ['2.2.2+cpu', '2.3.1+cpu', '2.4.1+cpu', '2.5.1+cpu']
os: [ubuntu-latest]

steps:
Expand Down
6 changes: 3 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ skorch also provides many convenient features, among others:
Installation
============

skorch requires Python 3.8 or higher.
skorch requires Python 3.9 or higher.

conda installation
==================
Expand Down Expand Up @@ -244,10 +244,10 @@ instructions for PyTorch, visit the `PyTorch website
<http://pytorch.org/>`__. skorch officially supports the last four
minor PyTorch versions, which currently are:

- 2.1.2
- 2.2.2
- 2.3.1
- 2.4.0
- 2.4.1
- 2.5.1

However, that doesn't mean that older versions don't work, just that
they aren't tested. Since skorch mostly relies on the stable part of
Expand Down
6 changes: 3 additions & 3 deletions docs/user/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ If you want to help developing, run:
pylint skorch # static code checks
You may adjust the Python version to any of the supported Python versions, i.e.
Python 3.8 or higher.
Python 3.9 or higher.

Using pip
^^^^^^^^^
Expand Down Expand Up @@ -98,10 +98,10 @@ instructions for PyTorch, visit the `PyTorch website
<http://pytorch.org/>`__. skorch officially supports the last four
minor PyTorch versions, which currently are:

- 2.1.2
- 2.2.2
- 2.3.1
- 2.4.0
- 2.4.1
- 2.5.1

However, that doesn't mean that older versions don't work, just that
they aren't tested. Since skorch mostly relies on the stable part of
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ numpydoc
openpyxl
pandas
pillow
protobuf>=3.12.0,<4.0dev
protobuf
pylint
pytest>=3.4
pytest-cov
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
tests_require = [l.strip() for l in f]


python_requires = '>=3.8'
python_requires = '>=3.9'

docs_require = [
'Sphinx',
Expand Down
28 changes: 24 additions & 4 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def test_pickle_save_and_load_mixed_devices(
with open(str(p), 'rb') as f:
if not expect_warning:
m = pickle.load(f)
assert not recwarn.list
assert not any(w.category == DeviceWarning for w in recwarn.list)
else:
with pytest.warns(DeviceWarning) as w:
m = pickle.load(f)
Expand All @@ -495,11 +495,17 @@ def test_pickle_save_and_load_mixed_devices(
# We should have captured two warnings:
# 1. one for the failed load
# 2. for switching devices on the net instance
assert len(w.list) == 2
assert w.list[0].message.args[0] == (
# remove possible future warning about weights_only=False
# TODO: remove filter when torch<=2.4 is dropped
w_list = [
warning for warning in w.list
if "weights_only=False" not in warning.message.args[0]
]
assert len(w_list) == 2
assert w_list[0].message.args[0] == (
'Requested to load data to CUDA but no CUDA devices '
'are available. Loading on device "cpu" instead.')
assert w.list[1].message.args[0] == (
assert w_list[1].message.args[0] == (
'Setting self.device = {} since the requested device ({}) '
'is not available.'.format(load_dev, save_dev))

Expand Down Expand Up @@ -4254,6 +4260,13 @@ def test_fit_and_predict_with_compile(self, net_cls, module_cls, data):
if not hasattr(torch, 'compile'):
pytest.skip(reason="torch.compile not available")

# python 3.12 requires torch >= 2.4 to support compile
# TODO: remove once we remove support for torch < 2.4
from skorch._version import Version

if Version(torch.__version__) < Version('2.4.0') and sys.version_info >= (3, 12):
pytest.skip(reason="When using Python 3.12, torch.compile requires torch >= 2.4")

# use real torch.compile, not mocked, can be a bit slow
X, y = data
net = net_cls(module_cls, max_epochs=1, compile=True).initialize()
Expand All @@ -4274,6 +4287,13 @@ def test_binary_classifier_with_compile(self, data):
# because of a failing isinstance check
from skorch import NeuralNetBinaryClassifier

# python 3.12 requires torch >= 2.4 to support compile
# TODO: remove once we remove support for torch < 2.4
from skorch._version import Version

if Version(torch.__version__) < Version('2.4.0') and sys.version_info >= (3, 12):
pytest.skip(reason="When using Python 3.12, torch.compile requires torch >= 2.4")

X, y = data[0], data[1].astype(np.float32)

class MyNet(nn.Module):
Expand Down
19 changes: 14 additions & 5 deletions skorch/tests/test_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,16 +134,16 @@ def test_dimension_mismatch_warning(self, net_cls, module_cls, data, recwarn):
X, y = X[:100], y[:100].flatten() # make y 1d
net.fit(X, y)

w0, w1 = recwarn.list # one warning for train, one for valid
# The warning comes from PyTorch, so checking the exact wording is prone to
# error in future PyTorch versions. We thus check a substring of the
# whole message and cross our fingers that it's not changed.
msg_substr = (
"This will likely lead to incorrect results due to broadcasting. "
"Please ensure they have the same size"
)
assert msg_substr in str(w0.message)
assert msg_substr in str(w1.message)
warn_list = [w for w in recwarn.list if msg_substr in str(w.message)]
# one warning for train, one for valid
assert len(warn_list) == 2

def test_fitting_with_1d_target_and_pred(
self, net_cls, module_cls, data, module_pred_1d_cls, recwarn
Expand All @@ -159,7 +159,11 @@ def test_fitting_with_1d_target_and_pred(

net = net_cls(module_pred_1d_cls)
net.fit(X, y)
assert not recwarn.list
msg_substr = (
"This will likely lead to incorrect results due to broadcasting. "
"Please ensure they have the same size"
)
assert not any(msg_substr in str(w.message) for w in recwarn.list)

def test_bagging_regressor(
self, net_cls, module_cls, data, module_pred_1d_cls, recwarn
Expand All @@ -173,4 +177,9 @@ def test_bagging_regressor(
y = y.flatten() # make y 1d or else sklearn will complain
regr = BaggingRegressor(net, n_estimators=2, random_state=0)
regr.fit(X, y) # does not raise
assert not recwarn.list # ensure there is no broadcast warning from torch
# ensure there is no broadcast warning from torch
msg_substr = (
"This will likely lead to incorrect results due to broadcasting. "
"Please ensure they have the same size"
)
assert not any(msg_substr in str(w.message) for w in recwarn.list)

0 comments on commit ad0259b

Please sign in to comment.