Skip to content

Commit

Permalink
Re-enable vision MPS builds (#8485)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
  • Loading branch information
huydhn and NicolasHug authored Jun 10, 2024
1 parent f1bcbd3 commit f96c42f
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-cmake.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
export GPU_ARCH_TYPE=cpu
export GPU_ARCH_VERSION=''
./.github/scripts/cmake.sh
${CONDA_RUN} ./.github/scripts/cmake.sh
windows:
strategy:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ jobs:
export GPU_ARCH_TYPE=cpu
export GPU_ARCH_VERSION=''
./.github/scripts/unittest.sh
${CONDA_RUN} ./.github/scripts/unittest.sh
unittests-windows:
strategy:
Expand Down
13 changes: 3 additions & 10 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import shutil
import subprocess
import sys
import warnings

import torch
from pkg_resources import DistributionNotFound, get_distribution, parse_version
Expand Down Expand Up @@ -139,6 +138,7 @@ def get_extensions():
+ glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
+ glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
)
source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))

print("Compiling extensions with following flags:")
force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
Expand Down Expand Up @@ -204,15 +204,8 @@ def get_extensions():
define_macros += [("WITH_HIP", None)]
nvcc_flags = []
extra_compile_args["nvcc"] = nvcc_flags

# FIXME: MPS build breaks custom ops registration, so it was disabled.
# See https://github.com/pytorch/vision/issues/8456.
# TODO: Fix MPS build, remove warning below, and put back commented-out elif block.V
if force_mps:
warnings.warn("MPS build is temporarily disabled!!!!")
# elif torch.backends.mps.is_available() or force_mps:
# source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))
# sources += source_mps
elif torch.backends.mps.is_available() or force_mps:
sources += source_mps

if sys.platform == "win32":
define_macros += [("torchvision_EXPORTS", None)]
Expand Down
3 changes: 1 addition & 2 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ def pytest_collection_modifyitems(items):
# There are special cases though, see below
item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG))

# TODO: uncoment when MPS works again - see FIXME in setup.py
if needs_mps: # and not torch.backends.mps.is_available():
if needs_mps and not torch.backends.mps.is_available():
item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG))

if IN_FBCODE:
Expand Down

0 comments on commit f96c42f

Please sign in to comment.