From 9847ea4fa4264da4b3ef92e2b4c445accda90b28 Mon Sep 17 00:00:00 2001 From: Alan deLevie Date: Mon, 8 Aug 2022 16:18:14 -0400 Subject: [PATCH] add mps support to utils.to_numpy --- skorch/tests/test_utils.py | 10 ++++++++++ skorch/utils.py | 3 +++ 2 files changed, 13 insertions(+) diff --git a/skorch/tests/test_utils.py b/skorch/tests/test_utils.py index 82783c06d..94b5d4b9b 100644 --- a/skorch/tests/test_utils.py +++ b/skorch/tests/test_utils.py @@ -199,6 +199,16 @@ def test_invalid_inputs(self, to_numpy, x_invalid): expected = "Cannot convert this data type to a numpy array." assert e.value.args[0] == expected + @pytest.mark.skipif( + not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()), + reason='Skipped because mps is not available as a torch backend' + ) + def test_mps_support(self, to_numpy, x_tensor): + device = torch.device('mps') + x_tensor.to(device) + x_numpy = to_numpy(x_tensor) + self.compare_array_to_tensor(x_numpy, x_tensor) + class TestToDevice: @pytest.fixture diff --git a/skorch/utils.py b/skorch/utils.py index 5fa570926..fa6c1d76e 100644 --- a/skorch/utils.py +++ b/skorch/utils.py @@ -165,6 +165,9 @@ def to_numpy(X): if X.is_cuda: X = X.cpu() + if hasattr(X, 'is_mps') and X.is_mps: + X = X.cpu() + if X.requires_grad: X = X.detach()