From 5de3ead712814bce33e244fc1d43bab8ab74c6cf Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 17 Feb 2023 18:30:20 +0000 Subject: [PATCH] [MPS] Add optional `minor` argument to `is_macos13_or_newer` (#95065) Will be needed if one wants to make accurate XFAIL validation I.e. `torch.backends.mps.is_macos13_or_newer()` will return True if PyTorch is running on MacOS 13.0 or newer, `torch.backends.mps.is_macos13_or_newer(1)` will return True if running on MacOS 13.1 or newer and `torch.backends.mps.is_macos13_or_newer(2)` will return True if running on MacOS 13.2 or newer Do not use 13.3 check as `@available` does not really work for shared libraries Pull Request resolved: https://github.com/pytorch/pytorch/pull/95065 Approved by: https://github.com/albanD --- aten/src/ATen/detail/MPSHooksInterface.h | 2 +- aten/src/ATen/mps/MPSHooks.cpp | 14 ++++++++++++-- aten/src/ATen/mps/MPSHooks.h | 2 +- torch/_C/__init__.pyi.in | 2 +- torch/backends/mps/__init__.py | 4 ++-- torch/csrc/mps/Module.cpp | 11 ++++++----- 6 files changed, 23 insertions(+), 12 deletions(-) diff --git a/aten/src/ATen/detail/MPSHooksInterface.h b/aten/src/ATen/detail/MPSHooksInterface.h index 27f4f193c63ae..827d441645f12 100644 --- a/aten/src/ATen/detail/MPSHooksInterface.h +++ b/aten/src/ATen/detail/MPSHooksInterface.h @@ -28,7 +28,7 @@ struct TORCH_API MPSHooksInterface { return false; } - virtual bool isOnMacOS13orNewer() const { + virtual bool isOnMacOS13orNewer(unsigned minor = 0) const { AT_ERROR("MPS backend is not available."); } diff --git a/aten/src/ATen/mps/MPSHooks.cpp b/aten/src/ATen/mps/MPSHooks.cpp index e71bfcc73922d..89adac6c34b15 100644 --- a/aten/src/ATen/mps/MPSHooks.cpp +++ b/aten/src/ATen/mps/MPSHooks.cpp @@ -17,8 +17,18 @@ bool MPSHooks::hasMPS() const { return at::mps::is_available(); } -bool MPSHooks::isOnMacOS13orNewer() const { - return at::mps::is_macos_13_or_newer(); +bool MPSHooks::isOnMacOS13orNewer(unsigned minor) const { + switch (minor) { + case 0: + return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_0_PLUS); + case 1: + return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_1_PLUS); + case 2: + return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS); + default: + TORCH_WARN("Can't check whether running on 13.",minor,"+ returning one for 13.2+"); + return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS); + } } Allocator* MPSHooks::getMPSDeviceAllocator() const { diff --git a/aten/src/ATen/mps/MPSHooks.h b/aten/src/ATen/mps/MPSHooks.h index 260113891d51d..9e913b38a2e10 100644 --- a/aten/src/ATen/mps/MPSHooks.h +++ b/aten/src/ATen/mps/MPSHooks.h @@ -13,7 +13,7 @@ struct MPSHooks : public at::MPSHooksInterface { MPSHooks(at::MPSHooksArgs) {} void initMPS() const override; bool hasMPS() const override; - bool isOnMacOS13orNewer() const override; + bool isOnMacOS13orNewer(unsigned minor) const override; Allocator* getMPSDeviceAllocator() const override; const Generator& getDefaultMPSGenerator() const override; void deviceSynchronize() const override; diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 1bd547cc3c6b2..b4f8510f6fc6c 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1207,7 +1207,7 @@ def _mps_setMemoryFraction(fraction: _float) -> None: ... def _mps_currentAllocatedMemory() -> _int: ... def _mps_driverAllocatedMemory() -> _int: ... def _mps_is_available() -> _bool: ... -def _mps_is_on_macos_13_or_newer() -> _bool: ... +def _mps_is_on_macos_13_or_newer(minor: _int) -> _bool: ... # Defined in torch/csrc/cuda/Module.cpp def _cuda_getCurrentStream(device: _int) -> Tuple: ... diff --git a/torch/backends/mps/__init__.py b/torch/backends/mps/__init__.py index 32f284f1d5003..2c6ef64665bc8 100644 --- a/torch/backends/mps/__init__.py +++ b/torch/backends/mps/__init__.py @@ -19,9 +19,9 @@ def is_available() -> bool: @_lru_cache() -def is_macos13_or_newer() -> bool: +def is_macos13_or_newer(minor: int = 0) -> bool: r"""Returns a bool indicating whether MPS is running on MacOS 13 or newer.""" - return torch._C._mps_is_on_macos_13_or_newer() + return torch._C._mps_is_on_macos_13_or_newer(minor) # Register prims as implementation of var_mean and group_norm diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp index ffbc3b9eceaaf..0a1c45c0838d1 100644 --- a/torch/csrc/mps/Module.cpp +++ b/torch/csrc/mps/Module.cpp @@ -59,11 +59,12 @@ static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) { END_HANDLE_TH_ERRORS } -static PyObject* MPSModule_isMacOS13orNewer( - PyObject* _unused, - PyObject* noargs) { +static PyObject* MPSModule_isMacOS13orNewer(PyObject* _unused, PyObject* args) { HANDLE_TH_ERRORS - if (at::detail::getMPSHooks().isOnMacOS13orNewer()) { + THPUtils_assert( + THPUtils_checkLong(args), "invalid argument to isOnMacOS13orNewer()"); + auto minor = THPUtils_unpackUInt32(args); + if (at::detail::getMPSHooks().isOnMacOS13orNewer(minor)) { Py_RETURN_TRUE; } else { Py_RETURN_FALSE; @@ -124,7 +125,7 @@ static struct PyMethodDef _MPSModule_methods[] = { {"_mps_is_available", MPSModule_isAvailable, METH_NOARGS, nullptr}, {"_mps_is_on_macos_13_or_newer", MPSModule_isMacOS13orNewer, - METH_NOARGS, + METH_O, nullptr}, {"_mps_get_default_generator", MPSModule_getDefaultMPSGenerator,