Introduce no_sync context wrapper + clean up some more warnings for DDP (#428)
muellerzr authored Jun 8, 2022
1 parent b2afd4e commit 1424a8e
Showing 10 changed files with 245 additions and 26 deletions.
20 changes: 20 additions & 0 deletions src/accelerate/accelerator.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import gc
import math
import os
@@ -355,6 +356,25 @@ def _goes_first(self, is_main):
if is_main:
self.wait_for_everyone()

@contextmanager
def no_sync(self, model):
"""
A context manager to disable gradient synchronizations across DDP processes by calling
`torch.nn.parallel.DistributedDataParallel.no_sync`.
If `model` is not wrapped in DDP, this context manager does nothing.
Args:
model (`torch.nn.Module`):
PyTorch Module that was prepared with `Accelerator.prepare`
"""
context = contextlib.nullcontext
if self.num_processes > 1:
context = getattr(model, "no_sync", context)

with context():
yield

def print(self, *args, **kwargs):
"""
Use in place of `print()` to only print once per server.
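For orientation, a minimal usage sketch of the new context manager, mirroring the gradient-accumulation pattern exercised by the test script added below; the `model`, `optimizer`, and `dataloader` names are illustrative and assumed to have been returned by `accelerator.prepare`:

import torch.nn.functional as F

from accelerate import Accelerator

accelerator = Accelerator()
# model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
for step, (inputs, targets) in enumerate(dataloader):
    if (step + 1) % 2 != 0:
        # Accumulate gradients locally; DDP skips the all-reduce for this forward/backward.
        with accelerator.no_sync(model):
            loss = F.mse_loss(model(inputs), targets)
            accelerator.backward(loss)
    else:
        # This forward/backward pass synchronizes gradients across processes.
        loss = F.mse_loss(model(inputs), targets)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()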
9 changes: 2 additions & 7 deletions src/accelerate/commands/launch.py
@@ -31,10 +31,10 @@
DistributedType,
PrecisionType,
PrepareForLaunch,
get_launch_prefix,
is_deepspeed_available,
is_sagemaker_available,
)
from accelerate.utils.versions import is_torch_version


def launch_command_parser(subparsers=None):
@@ -251,12 +251,7 @@ def simple_launcher(args):


def multi_gpu_launcher(args):
if is_torch_version(">=", "1.10.0"):
cmd = ["torchrun"]
elif is_torch_version(">=", "1.9.0"):
cmd = [sys.executable, "-m", "torch.distributed.run"]
else:
cmd = [sys.executable, "-m", "torch.distributed.launch", "--use_env"]
cmd = get_launch_prefix()
if args.num_machines > 1:
cmd.extend(
[
4 changes: 4 additions & 0 deletions src/accelerate/test_utils/__init__.py
@@ -8,7 +8,11 @@
require_cpu,
require_cuda,
require_multi_gpu,
require_single_gpu,
require_tpu,
slow,
)
from .training import RegressionDataset, RegressionModel


from .scripts import test_script, test_sync # isort:skip
File renamed without changes.
130 changes: 130 additions & 0 deletions src/accelerate/test_utils/scripts/test_sync.py
@@ -0,0 +1,130 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from accelerate import Accelerator
from accelerate.test_utils import RegressionDataset, RegressionModel
from accelerate.utils import DistributedType, set_seed


def step_model(model, input, target, accelerator):
model.train()
output = model(input)
loss = F.mse_loss(output, target.to(output.device))
accelerator.backward(loss)


def get_training_setup(accelerator):
"Returns everything needed to perform basic training"
set_seed(42)
model = RegressionModel()
model.to(accelerator.device)
dset = RegressionDataset()
dataloader = DataLoader(dset, batch_size=16)
# Make a copy of `model`
ddp_model, dataloader = accelerator.prepare(deepcopy(model), dataloader)
# Use a single batch for all of the tests
ddp_input, ddp_target = next(iter(dataloader)).values()
return model, ddp_model, ddp_input, ddp_target


def test_noop_sync(accelerator):
# Test when on a single CPU or GPU that the context manager does nothing
model, ddp_model, ddp_input, ddp_target = get_training_setup(accelerator)
for iteration in range(3):
# Gather the distributed inputs and targets for the base model
input, target = accelerator.gather((ddp_input, ddp_target))
input, target = input.to(accelerator.device), target.to(accelerator.device)
# Perform our initial ground truth step with the non-DDP model
step_model(model, input, target, accelerator)
# Do "gradient accumulation" (noop)
if iteration % 2 == 0:
# Accumulate grads locally
with accelerator.no_sync(ddp_model):
step_model(ddp_model, ddp_input, ddp_target, accelerator)
else:
# Sync grads
step_model(ddp_model, ddp_input, ddp_target, accelerator)

# Since `no_sync` is a noop, `ddp_model` and `model` grads should always be in sync
for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
if not param.requires_grad:
continue
assert torch.allclose(
param.grad, ddp_param.grad
), f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"

# Shuffle ddp_input on each iteration
torch.manual_seed(1337 + iteration)
ddp_input = ddp_input[torch.randperm(16)]


def test_distributed_sync(accelerator):
# Test on distributed setup that context manager behaves properly
model, ddp_model, ddp_input, ddp_target = get_training_setup(accelerator)
for iteration in range(3):
# Gather the distributed inputs and targets for the base model
input, target = accelerator.gather((ddp_input, ddp_target))
input, target = input.to(accelerator.device), target.to(accelerator.device)
# Perform our initial ground truth step with the non-DDP model
step_model(model, input, target, accelerator)
# Do "gradient accumulation" (noop)
if iteration % 2 == 0:
# Accumulate grads locally
with accelerator.no_sync(ddp_model):
step_model(ddp_model, ddp_input, ddp_target, accelerator)
else:
# Sync grads
step_model(ddp_model, ddp_input, ddp_target, accelerator)

# DDP model and model should only be in sync when not (iteration % 2 == 0)
for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
if not param.requires_grad:
continue
if iteration % 2 == 0:
# Grads should not be in sync
assert (
torch.allclose(param.grad, ddp_param.grad) is False
), f"Gradients in sync when they should not be:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})"
else:
# Grads should be in sync
assert (
torch.allclose(param.grad, ddp_param.grad) is True
), f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"

# Shuffle ddp_input on each iteration
torch.manual_seed(1337 + iteration)
ddp_input = ddp_input[torch.randperm(16)]


def main():
accelerator = Accelerator()
state = accelerator.state
if state.distributed_type == DistributedType.NO:
if state.local_process_index == 0:
print("**NOOP `no_sync` gradient accumulation**")
test_noop_sync(accelerator)
if state.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_CPU):
if state.local_process_index == 0:
print("**Distributed `no_sync` gradient accumulation**")
test_distributed_sync(accelerator)


if __name__ == "__main__":
main()
8 changes: 8 additions & 0 deletions src/accelerate/test_utils/testing.py
@@ -84,6 +84,14 @@ def require_tpu(test_case):
return unittest.skipUnless(is_tpu_available(), "test requires TPU")(test_case)


def require_single_gpu(test_case):
"""
Decorator marking a test that requires CUDA on a single GPU. These tests are skipped when there is no GPU
available or when more than one GPU is available.
"""
return unittest.skipUnless(torch.cuda.device_count() == 1, "test requires a single GPU")(test_case)


def require_multi_gpu(test_case):
"""
Decorator marking a test that requires a multi-GPU setup. These tests are skipped on a machine without multiple
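A brief, hedged sketch of how the new `require_single_gpu` decorator could be applied in a test module (the test class and body are invented for illustration):

import unittest

import torch

from accelerate.test_utils import require_single_gpu


class ExampleTester(unittest.TestCase):
    @require_single_gpu
    def test_runs_with_exactly_one_gpu(self):
        # Skipped unless torch.cuda.device_count() == 1, per the decorator above.
        self.assertEqual(torch.cuda.device_count(), 1)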
2 changes: 1 addition & 1 deletion src/accelerate/utils/__init__.py
@@ -85,7 +85,7 @@
DummyScheduler,
)

from .launch import PrepareForLaunch
from .launch import PrepareForLaunch, get_launch_prefix
from .memory import find_executable_batch_size
from .other import (
extract_model_from_parallel,
16 changes: 16 additions & 0 deletions src/accelerate/utils/launch.py
@@ -13,12 +13,28 @@
# limitations under the License.

import os
import sys

import torch

from ..utils import is_torch_version
from .dataclasses import DistributedType


def get_launch_prefix():
"""
Grabs the correct launcher for starting a distributed command, such as `torchrun` or `python -m
torch.distributed.run`, depending on the installed torch version.
"""
if is_torch_version(">=", "1.10.0"):
cmd = ["torchrun"]
elif is_torch_version(">=", "1.9.0"):
cmd = [sys.executable, "-m", "torch.distributed.run"]
else:
cmd = [sys.executable, "-m", "torch.distributed.launch", "--use_env"]
return cmd


class PrepareForLaunch:
"""
Prepare a function that will be launched in a distributed setup.
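As a rough illustration, the new `get_launch_prefix` helper can be combined with per-run arguments to build a launch command, much as the updated tests below do (the process count and script path are placeholders):

from accelerate.utils import get_launch_prefix

# Resolves to `torchrun`, `python -m torch.distributed.run`, or the legacy
# `torch.distributed.launch --use_env`, depending on the installed torch version.
cmd = get_launch_prefix() + ["--nproc_per_node=2", "path/to/training_script.py"]
print(" ".join(cmd))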
56 changes: 56 additions & 0 deletions tests/test_grad_sync.py
@@ -0,0 +1,56 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
import unittest

import torch

import accelerate
from accelerate import debug_launcher
from accelerate.test_utils import (
execute_subprocess_async,
require_cpu,
require_multi_gpu,
require_single_gpu,
test_sync,
)
from accelerate.utils import get_launch_prefix, patch_environment


class SyncScheduler(unittest.TestCase):
def setUp(self):
mod_file = inspect.getfile(accelerate.test_utils)
self.test_file_path = os.path.sep.join(mod_file.split(os.path.sep)[:-1] + ["scripts", "test_sync.py"])

@require_cpu
def test_gradient_sync_single_cpu_noop(self):
debug_launcher(test_sync.main)
debug_launcher(test_sync.main, num_processes=1)

@require_cpu
def test_gradient_sync_multi_cpu(self):
debug_launcher(test_sync.main)

@require_single_gpu
def test_gradient_sync_single_gpu(self):
debug_launcher(test_sync.main, num_processes=1)

@require_multi_gpu
def test_gradient_sync_multi_gpu(self):
print(f"Found {torch.cuda.device_count()} devices.")
cmd = get_launch_prefix() + [f"--nproc_per_node={torch.cuda.device_count()}", self.test_file_path]
with patch_environment(omp_num_threads=1):
execute_subprocess_async(cmd, env=os.environ.copy())
26 changes: 8 additions & 18 deletions tests/test_multigpu.py
@@ -14,43 +14,33 @@

import inspect
import os
import sys
import unittest

import torch

import accelerate
from accelerate import Accelerator
from accelerate.test_utils import execute_subprocess_async, require_multi_gpu
from accelerate.utils import get_launch_prefix, patch_environment


class MultiGPUTester(unittest.TestCase):
def setUp(self):
mod_file = inspect.getfile(accelerate.test_utils)
self.test_file_path = os.path.sep.join(mod_file.split(os.path.sep)[:-1] + ["test_script.py"])
self.test_file_path = os.path.sep.join(mod_file.split(os.path.sep)[:-1] + ["scripts", "test_script.py"])

@require_multi_gpu
def test_multi_gpu(self):
print(f"Found {torch.cuda.device_count()} devices.")
distributed_args = f"""
-m torch.distributed.launch
--nproc_per_node={torch.cuda.device_count()}
--use_env
{self.test_file_path}
""".split()
cmd = [sys.executable] + distributed_args
execute_subprocess_async(cmd, env=os.environ.copy())
cmd = get_launch_prefix() + [self.test_file_path]
with patch_environment(omp_num_threads=1):
execute_subprocess_async(cmd, env=os.environ.copy())

@require_multi_gpu
def test_pad_across_processes(self):
distributed_args = f"""
-m torch.distributed.launch
--nproc_per_node={torch.cuda.device_count()}
--use_env
{inspect.getfile(self.__class__)}
""".split()
cmd = [sys.executable] + distributed_args
execute_subprocess_async(cmd, env=os.environ.copy())
cmd = get_launch_prefix() + [inspect.getfile(self.__class__)]
with patch_environment(omp_num_threads=1):
execute_subprocess_async(cmd, env=os.environ.copy())


if __name__ == "__main__":
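For reference, a hedged sketch of the `patch_environment` pattern used in these tests, assuming it behaves as a context manager that temporarily sets the given environment variables for the duration of the block:

import os

from accelerate.utils import patch_environment

with patch_environment(omp_num_threads=1):
    # Any subprocess started with os.environ.copy() here inherits OMP_NUM_THREADS=1.
    print(os.environ.get("OMP_NUM_THREADS"))
# After the block, the variable is expected to be restored to its previous state.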
