Add support for MPS Backend [without torch.amp.autocast] #2993

Merged: 7 commits, Aug 23, 2023

114 changes: 114 additions & 0 deletions .github/workflows/mps-tests.yml
@@ -0,0 +1,114 @@
name: Run unit tests on M1
on:
  push:
    branches:
      - master
      - "*.*.*"
    paths:
      - "ignite/**"
      - "tests/ignite/**"
      - "tests/run_code_style.sh"
      - "examples/**.py"
      - "requirements-dev.txt"
      - ".github/workflows/mps-tests.yml"
  pull_request:
    paths:
      - "ignite/**"
      - "tests/ignite/**"
      - "tests/run_code_style.sh"
      - "examples/**.py"
      - "requirements-dev.txt"
      - ".github/workflows/mps-tests.yml"
  workflow_dispatch:

concurrency:
  # <workflow_name>-<branch_name>-<true || commit_sha (if branch is protected)>
  group: mps-tests-${{ github.ref_name }}-${{ !(github.ref_protected) || github.sha }}
  cancel-in-progress: true

# Cherry-picked from
# - https://github.com/pytorch/vision/main/.github/workflows/tests.yml
# - https://github.com/pytorch/test-infra/blob/main/.github/workflows/macos_job.yml

jobs:
  mps-tests:
    strategy:
      matrix:
        python-version:
          - "3.8"
        pytorch-channel: ["pytorch"]
        skip-distrib-tests: 1
      fail-fast: false
    runs-on: ["macos-m1-12"]
    timeout-minutes: 60

    steps:
      - name: Clean workspace
        run: |
          echo "::group::Cleanup debug output"
          sudo rm -rfv "${GITHUB_WORKSPACE}"
          mkdir -p "${GITHUB_WORKSPACE}"
          echo "::endgroup::"

      - name: Checkout repository (pytorch/test-infra)
        uses: actions/checkout@v3
        with:
          # Support the use case where we need to checkout someone's fork
          repository: pytorch/test-infra
          path: test-infra

      - name: Setup miniconda
        uses: ./test-infra/.github/actions/setup-miniconda
        with:
          python-version: ${{ matrix.python-version }}

      - name: Checkout repository (${{ github.repository }})
        uses: actions/checkout@v3
        with:
          # Support the use case where we need to checkout someone's fork
          repository: ${{ github.repository }}
          ref: ${{ github.ref }}
          path: ${{ github.repository }}
          fetch-depth: 1

      - name: Install PyTorch
        if: ${{ matrix.pytorch-channel == 'pytorch' }}
        shell: bash -l {0}
        continue-on-error: false
        run: pip install --upgrade torch torchvision

      - name: Install PyTorch (nightly)
        if: ${{ matrix.pytorch-channel == 'pytorch-nightly' }}
        run: pip install torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre

      - name: Install dependencies
        run: |
          pip install -r requirements-dev.txt
          pip install -e .
          pip list

      # Download MNIST: https://github.com/pytorch/ignite/issues/1737
      # to "/tmp" for unit tests
      - name: Download MNIST
        uses: pytorch-ignite/download-mnist-github-action@master
        with:
          target_dir: /tmp

      # Copy MNIST to "." for the examples
      - name: Copy MNIST
        run: |
          cp -R /tmp/MNIST .

      - name: Run Tests
        run: |
          SKIP_DISTRIB_TESTS=${{ matrix.skip-distrib-tests }} bash tests/run_cpu_tests.sh

      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
          file: ${{ github.repository }}/coverage.xml
          flags: mps
          fail_ci_if_error: false

      - name: Run MNIST Examples
        run: python examples/mnist/mnist.py --epochs=1
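
For contributors without access to the `macos-m1-12` runner, a minimal local sanity check that the MPS backend is actually usable (a sketch assuming a recent PyTorch build on Apple Silicon; not part of this PR) could look like:

```python
import torch

# Sketch: confirm the MPS backend is built and usable before running the
# MPS-specific tests locally.
if torch.backends.mps.is_available():
    device = torch.device("mps")
    x = torch.randn(4, 4, device=device)
    y = (x @ x.T).cpu()  # run a matmul on MPS and copy the result back to CPU
    print("MPS backend is usable, result shape:", tuple(y.shape))
else:
    # is_built() distinguishes "PyTorch compiled without MPS support"
    # from "this macOS version/hardware cannot use MPS".
    print("MPS available:", torch.backends.mps.is_available())
    print("MPS built:", torch.backends.mps.is_built())
```
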
1 change: 1 addition & 0 deletions .github/workflows/unit-tests.yml
@@ -93,6 +93,7 @@ jobs:
        run: |
          pip install -r requirements-dev.txt
          python setup.py install
+         pip list

      - name: Check code formatting
        run: |
2 changes: 2 additions & 0 deletions ignite/distributed/comp_models/base.py
@@ -325,6 +325,8 @@
    def device(self) -> torch.device:
        if torch.cuda.is_available():
            return torch.device("cuda")
+       if torch.backends.mps.is_available():
+           return torch.device("mps")
        return torch.device("cpu")

[Codecov / codecov/patch warning: added line ignite/distributed/comp_models/base.py#L329 was not covered by tests]

    def backend(self) -> Optional[str]:
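
With this change, ignite's device helper should prefer MPS over CPU on Apple Silicon machines that have no CUDA device. A hedged sketch of the user-visible effect (assuming `idist.device()` resolves through the method patched above):

```python
import torch
import ignite.distributed as idist

# Sketch: on an M1/M2 Mac without CUDA, idist.device() is expected to
# return the MPS device after this change instead of falling back to CPU.
device = idist.device()
print(device)  # device(type='mps') on Apple Silicon, device(type='cuda') if CUDA is present

model = torch.nn.Linear(8, 2).to(device)
x = torch.randn(4, 8, device=device)
print(model(x).shape)  # torch.Size([4, 2])
```
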
19 changes: 15 additions & 4 deletions ignite/engine/__init__.py
@@ -91,6 +91,8 @@
        Added Gradient Accumulation.
    .. versionchanged:: 0.4.11
        Added `model_transform` to transform model's output
+   .. versionchanged:: 0.4.13
+       Added support for ``mps`` device
    """

    if gradient_accumulation_steps <= 0:
@@ -374,9 +376,12 @@


def _check_arg(
-   on_tpu: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.cuda.amp.GradScaler"]]
+   on_tpu: bool, on_mps: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.cuda.amp.GradScaler"]]
) -> Tuple[Optional[str], Optional["torch.cuda.amp.GradScaler"]]:
-   """Checking tpu, amp and GradScaler instance combinations."""
+   """Checking tpu, mps, amp and GradScaler instance combinations."""
+   if on_mps and amp_mode:
+       raise ValueError("amp_mode cannot be used with mps device. Consider using amp_mode=None or device='cuda'.")

[Codecov / codecov/patch warning: added line ignite/engine/__init__.py#L383 was not covered by tests]

    if on_tpu and not idist.has_xla_support:
        raise RuntimeError("In order to run on TPU, please install PyTorch XLA")
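
The new guard makes the unsupported MPS-plus-AMP combination fail fast. A sketch of what a user would see (assuming `create_supervised_trainer` forwards `device` and `amp_mode` into `_check_arg`, as the later hunks show):

```python
import torch
from ignite.engine import create_supervised_trainer

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()

# Sketch: requesting mixed precision together with an "mps" device is now
# rejected up front instead of failing later inside torch.cuda.amp.
try:
    create_supervised_trainer(model, optimizer, criterion, device="mps", amp_mode="amp")
except ValueError as e:
    print(e)  # amp_mode cannot be used with mps device. Consider using amp_mode=None or device='cuda'.
```
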

@@ -525,11 +530,14 @@
        Added Gradient Accumulation argument for all supervised training methods.
    .. versionchanged:: 0.4.11
        Added ``model_transform`` to transform model's output
+   .. versionchanged:: 0.4.13
+       Added support for ``mps`` device
    """

    device_type = device.type if isinstance(device, torch.device) else device
    on_tpu = "xla" in device_type if device_type is not None else False
-   mode, _scaler = _check_arg(on_tpu, amp_mode, scaler)
+   on_mps = "mps" in device_type if device_type is not None else False
+   mode, _scaler = _check_arg(on_tpu, on_mps, amp_mode, scaler)

if mode == "amp":
_update = supervised_training_step_amp(
@@ -754,10 +762,13 @@
        Added ``amp_mode`` argument for automatic mixed precision.
    .. versionchanged:: 0.4.12
        Added ``model_transform`` to transform model's output
+   .. versionchanged:: 0.4.13
+       Added support for ``mps`` device
    """
    device_type = device.type if isinstance(device, torch.device) else device
    on_tpu = "xla" in device_type if device_type is not None else False
-   mode, _ = _check_arg(on_tpu, amp_mode, None)
+   on_mps = "mps" in device_type if device_type is not None else False
+   mode, _ = _check_arg(on_tpu, on_mps, amp_mode, None)

    metrics = metrics or {}
    if mode == "amp":
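
Taken together, a minimal end-to-end sketch of what the engine changes enable (assuming an Apple Silicon machine where `torch.backends.mps.is_available()` is true; this snippet is illustrative, not taken from the PR):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Loss

# Sketch: train and evaluate on the MPS backend, which these changes accept
# as a device (without amp_mode, which is explicitly rejected for MPS).
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()

dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
loader = DataLoader(dataset, batch_size=16)

trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
evaluator = create_supervised_evaluator(model, metrics={"loss": Loss(criterion)}, device=device)

trainer.run(loader, max_epochs=2)
evaluator.run(loader)
print(evaluator.state.metrics["loss"])
```
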
35 changes: 32 additions & 3 deletions tests/ignite/engine/test_create_supervised.py
@@ -184,7 +184,8 @@ def _test_create_mocked_supervised_trainer(
    data = [(x, y)]

on_tpu = "xla" in trainer_device if trainer_device is not None else False
mode, _ = _check_arg(on_tpu, amp_mode, scaler)
on_mps = "mps" in trainer_device if trainer_device is not None else False
mode, _ = _check_arg(on_tpu, on_mps, amp_mode, scaler)

    if model_device == trainer_device or ((model_device == "cpu") ^ (trainer_device == "cpu")):
        trainer.run(data)
@@ -336,7 +337,8 @@ def _test_create_evaluation_step_amp(

    device_type = evaluator_device.type if isinstance(evaluator_device, torch.device) else evaluator_device
    on_tpu = "xla" in device_type if device_type is not None else False
-   mode, _ = _check_arg(on_tpu, amp_mode, None)
+   on_mps = "mps" in device_type if device_type is not None else False
+   mode, _ = _check_arg(on_tpu, on_mps, amp_mode, None)

    evaluate_step = supervised_evaluation_step_amp(model, evaluator_device, output_transform=output_transform_mock)

@@ -371,7 +373,8 @@ def _test_create_evaluation_step(

    device_type = evaluator_device.type if isinstance(evaluator_device, torch.device) else evaluator_device
    on_tpu = "xla" in device_type if device_type is not None else False
-   mode, _ = _check_arg(on_tpu, amp_mode, None)
+   on_mps = "mps" in device_type if device_type is not None else False
+   mode, _ = _check_arg(on_tpu, on_mps, amp_mode, None)

    evaluate_step = supervised_evaluation_step(model, evaluator_device, output_transform=output_transform_mock)

@@ -452,6 +455,19 @@ def test_create_supervised_trainer_on_cuda():
    _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)


+@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="Skip if no MPS")
+def test_create_supervised_trainer_on_mps():
+   model_device = trainer_device = "mps"
+   _test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
+   _test_create_supervised_trainer(
+       gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device
+   )
+   _test_create_supervised_trainer(
+       gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device
+   )
+   _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)


@pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_create_supervised_trainer_on_cuda_amp():
@@ -618,6 +634,19 @@ def test_create_supervised_evaluator_on_cuda_with_model_on_cpu():
    _test_mocked_supervised_evaluator(evaluator_device="cuda")


+@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="Skip if no MPS Backend")
+def test_create_supervised_evaluator_on_mps():
+   model_device = evaluator_device = "mps"
+   _test_create_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)
+   _test_mocked_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)


+@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="Skip if no MPS Backend")
+def test_create_supervised_evaluator_on_mps_with_model_on_cpu():
+   _test_create_supervised_evaluator(evaluator_device="mps")
+   _test_mocked_supervised_evaluator(evaluator_device="mps")


@pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_create_supervised_evaluator_on_cuda_amp():
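
The private helper exercised by these tests gains a new positional flag. A small sketch of the argument order as it appears in the diff (treating `_check_arg` as an internal API that may change):

```python
import pytest
from ignite.engine import _check_arg

# Sketch mirroring the updated tests: argument order taken from the diff
# is (on_tpu, on_mps, amp_mode, scaler).
mode, scaler = _check_arg(False, True, None, None)  # mps without amp_mode is accepted

with pytest.raises(ValueError, match="amp_mode cannot be used with mps device"):
    _check_arg(False, True, "amp", None)  # mps combined with amp_mode is rejected
```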