diff --git a/keras_core/backend/torch/core.py b/keras_core/backend/torch/core.py
index d2ced0f51..99a942b54 100644
--- a/keras_core/backend/torch/core.py
+++ b/keras_core/backend/torch/core.py
@@ -1,4 +1,5 @@
 import contextlib
+import os
 
 import numpy as np
 import torch
@@ -12,7 +13,18 @@ from keras_core.utils.nest import pack_sequence_as
 
 DYNAMIC_SHAPES_OK = True
 
-DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+# Some operators such as 'aten::_foreach_mul_.Scalar'
+# are not currently implemented for the MPS device.
+# check https://github.com/pytorch/pytorch/issues/77764.
+if (
+    torch.backends.mps.is_available()
+    and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") == "1"
+):
+    DEFAULT_DEVICE = "mps"
+elif torch.cuda.is_available():
+    DEFAULT_DEVICE = "cuda"
+else:
+    DEFAULT_DEVICE = "cpu"
 
 TORCH_DTYPES = {
     "float16": torch.float16,
@@ -145,7 +157,7 @@ def transform(x):
         if x.requires_grad:
             x = x.detach()
         # Tensor has to be moved to CPU before converting to numpy.
-        if x.is_cuda:
+        if x.is_cuda or x.is_mps:
             x = x.cpu()
         return np.array(x)
 
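For reference, a minimal usage sketch of the device-selection logic this patch introduces. The hardware, tensor values, and the KERAS_BACKEND setting are illustrative assumptions, not part of the patch itself; the key point, which the patch relies on, is that PYTORCH_ENABLE_MPS_FALLBACK must be set to "1" before torch is imported so that operators missing on MPS fall back to the CPU.

# Hypothetical opt-in example, assuming macOS on Apple Silicon with a
# PyTorch build that supports MPS. The fallback flag is read when torch
# is imported, so it has to be set first.
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["KERAS_BACKEND"] = "torch"  # select the torch backend for keras_core

import torch

if torch.backends.mps.is_available():
    x = torch.ones(3, device="mps")  # tensor allocated on the Apple GPU
    print(x.is_mps)                  # True; mirrors the new is_mps check above
    print(x.cpu().numpy())           # tensors must move to CPU before NumPy conversion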