Skip to content

Commit

Permalink
add mps support to utils.to_numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
adelevie committed Aug 8, 2022
1 parent 64a2d16 commit 9847ea4
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
10 changes: 10 additions & 0 deletions skorch/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,16 @@ def test_invalid_inputs(self, to_numpy, x_invalid):
expected = "Cannot convert this data type to a numpy array."
assert e.value.args[0] == expected

@pytest.mark.skipif(
not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()),
reason='Skipped because mps is not available as a torch backend'
)
def test_mps_support(self, to_numpy, x_tensor):
device = torch.device('mps')
x_tensor.to(device)
x_numpy = to_numpy(x_tensor)
self.compare_array_to_tensor(x_numpy, x_tensor)


class TestToDevice:
@pytest.fixture
Expand Down
3 changes: 3 additions & 0 deletions skorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ def to_numpy(X):
if X.is_cuda:
X = X.cpu()

if hasattr(X, 'is_mps') and X.is_mps:
X = X.cpu()

if X.requires_grad:
X = X.detach()

Expand Down

0 comments on commit 9847ea4

Please sign in to comment.