diff --git a/.github/workflows/build-cmake.yml b/.github/workflows/build-cmake.yml index 1dce7b8446a..1ab20669f0b 100644 --- a/.github/workflows/build-cmake.yml +++ b/.github/workflows/build-cmake.yml @@ -54,7 +54,7 @@ jobs: export GPU_ARCH_TYPE=cpu export GPU_ARCH_VERSION='' - ./.github/scripts/cmake.sh + ${CONDA_RUN} ./.github/scripts/cmake.sh windows: strategy: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ad327129912..533a44a9e84 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -68,7 +68,7 @@ jobs: export GPU_ARCH_TYPE=cpu export GPU_ARCH_VERSION='' - ./.github/scripts/unittest.sh + ${CONDA_RUN} ./.github/scripts/unittest.sh unittests-windows: strategy: diff --git a/setup.py b/setup.py index 753a50ffeed..fedbc370f72 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,6 @@ import shutil import subprocess import sys -import warnings import torch from pkg_resources import DistributionNotFound, get_distribution, parse_version @@ -139,6 +138,7 @@ def get_extensions(): + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp")) + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp")) ) + source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm")) print("Compiling extensions with following flags:") force_cuda = os.getenv("FORCE_CUDA", "0") == "1" @@ -204,15 +204,8 @@ def get_extensions(): define_macros += [("WITH_HIP", None)] nvcc_flags = [] extra_compile_args["nvcc"] = nvcc_flags - - # FIXME: MPS build breaks custom ops registration, so it was disabled. - # See https://github.com/pytorch/vision/issues/8456. - # TODO: Fix MPS build, remove warning below, and put back commented-out elif block.V - if force_mps: - warnings.warn("MPS build is temporarily disabled!!!!") - # elif torch.backends.mps.is_available() or force_mps: - # source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm")) - # sources += source_mps + elif torch.backends.mps.is_available() or force_mps: + sources += source_mps if sys.platform == "win32": define_macros += [("torchvision_EXPORTS", None)] diff --git a/test/conftest.py b/test/conftest.py index 89b4946e612..a9768598ded 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -49,8 +49,7 @@ def pytest_collection_modifyitems(items): # There are special cases though, see below item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG)) - # TODO: uncoment when MPS works again - see FIXME in setup.py - if needs_mps: # and not torch.backends.mps.is_available(): + if needs_mps and not torch.backends.mps.is_available(): item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG)) if IN_FBCODE: