Allow custom device placements for different objects #716

Merged (4 commits) on Sep 23, 2022
src/accelerate/accelerator.py (65 changes: 48 additions & 17 deletions)
@@ -547,15 +547,15 @@ def print(self, *args, **kwargs):
if self.is_local_main_process:
print(*args, **kwargs)

def _prepare_one(self, obj, first_pass=False):
def _prepare_one(self, obj, first_pass=False, device_placement=None):
# First pass of preparation: DataLoader, model, optimizer
if first_pass:
if isinstance(obj, torch.utils.data.DataLoader):
return self.prepare_data_loader(obj)
return self.prepare_data_loader(obj, device_placement=device_placement)
elif isinstance(obj, torch.nn.Module):
return self.prepare_model(obj)
return self.prepare_model(obj, device_placement=device_placement)
elif isinstance(obj, torch.optim.Optimizer):
optimizer = self.prepare_optimizer(obj)
optimizer = self.prepare_optimizer(obj, device_placement=device_placement)
return optimizer
# Second pass of preparation: LR scheduler (which need the full list of optimizers)
elif isinstance(obj, torch.optim.lr_scheduler._LRScheduler):
@@ -602,17 +602,33 @@ def _prepare_fsdp(self, *args):
self._optimizers = optimizers
return tuple(result)

def prepare(self, *args):
def prepare(self, *args, device_placement=None):
"""
Prepare all objects passed in `args` for distributed training and mixed precision, then return them in the same
order.

Accepts the following type of objects:
Args:
*args (list of objects):
Any of the following type of objects:

- `torch.utils.data.DataLoader`: PyTorch Dataloader
- `torch.nn.Module`: PyTorch Module
- `torch.optim.Optimizer`: PyTorch Optimizer
- `torch.optim.lr_scheduler._LRScheduler`: PyTorch LR Scheduler

- `torch.utils.data.DataLoader`: PyTorch Dataloader
- `torch.nn.Module`: PyTorch Module
- `torch.optim.Optimizer`: PyTorch Optimizer
device_placement (`List[bool]`, *optional*):
Used to customize whether automatic device placement should be performed for each object passed. Needs
to be a list of the same length as `args`.
"""
if device_placement is None:
device_placement = [None for _ in args]
elif self.distributed_type == DistributedType.DEEPSPEED:
raise ValueError("You can't customize device placements with DeepSpeed.")
elif len(device_placement) != len(args):
raise ValueError(
f"`device_placement` should be a list with {len(args)} elements (the number of objects passed)."
)

if self.distributed_type == DistributedType.FSDP:
model_count = 0
optimizer_present = False
Expand Down Expand Up @@ -656,8 +672,10 @@ def prepare(self, *args):
if self.distributed_type == DistributedType.DEEPSPEED:
result = self._prepare_deepspeed(*args)
else:
result = tuple(self._prepare_one(obj, first_pass=True) for obj in args)
result = tuple(self._prepare_one(obj) for obj in result)
result = tuple(
self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
)
result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement))

if tpu_should_fix_optimizer:
# 2. grabbing new model parameters
@@ -674,17 +692,21 @@

return result if len(result) > 1 else result[0]
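For illustration, a minimal usage sketch of the new per-object `device_placement` argument to `prepare` (the toy model, optimizer, and dataloader below are hypothetical stand-ins; any real training objects work the same way):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Hypothetical toy objects standing in for a real training setup.
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(torch.randn(32, 8), torch.randint(0, 2, (32,)))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)

# One boolean per object passed to `prepare`: place the model automatically, but skip
# automatic placement for the optimizer state and the dataloader batches.
model, optimizer, dataloader = accelerator.prepare(
    model, optimizer, dataloader, device_placement=[True, False, False]
)
# Note: per the check above, a custom `device_placement` raises a ValueError under
# DeepSpeed, and its length must match the number of objects passed.
```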

def prepare_model(self, model: torch.nn.Module):
def prepare_model(self, model: torch.nn.Module, device_placement=None):
"""
Prepares a PyTorch model for training in any distributed setup. It is recommended to use
[`Accelerator.prepare`] instead.

Args:
model (`torch.nn.Module`):
A PyTorch model to prepare
device_placement (`bool`, *optional*):
Whether or not to place the model on the proper device. Will default to `self.device_placement`.
"""
if device_placement is None:
device_placement = self.device_placement and self.distributed_type != DistributedType.FSDP
self._models.append(model)
if self.device_placement and self.distributed_type != DistributedType.FSDP:
if device_placement:
model = model.to(self.device)
if self.distributed_type == DistributedType.MULTI_GPU:
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
@@ -894,36 +916,45 @@ def _prepare_deepspeed(self, *args):
)
return tuple(result)

def prepare_data_loader(self, data_loader: torch.utils.data.DataLoader):
def prepare_data_loader(self, data_loader: torch.utils.data.DataLoader, device_placement=None):
"""
Prepares a PyTorch DataLoader for training in any distributed setup. It is recommended to use
[`Accelerator.prepare`] instead.

Args:
data_loader (`torch.utils.data.DataLoader`):
A vanilla PyTorch DataLoader to prepare
device_placement (`bool`, *optional*):
Whether or not to place the batches on the proper device in the prepared dataloader. Will default to
`self.device_placement`.
"""
if device_placement is None:
device_placement = self.device_placement if self.distributed_type != DistributedType.TPU else False
return prepare_data_loader(
data_loader,
self.device,
num_processes=self.num_processes,
process_index=self.process_index,
split_batches=self.split_batches,
put_on_device=self.device_placement if self.distributed_type != DistributedType.TPU else False,
put_on_device=device_placement,
rng_types=self.rng_types.copy(),
dispatch_batches=self.dispatch_batches,
)
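As a sketch of calling the standalone method with the new flag (the toy dataset below is hypothetical), opting out of automatic batch placement and moving batches manually:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
dataset = torch.utils.data.TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)

# device_placement=False overrides self.device_placement for this dataloader only,
# so batches stay where they are and the user moves them explicitly.
dataloader = accelerator.prepare_data_loader(dataloader, device_placement=False)

for inputs, labels in dataloader:
    inputs, labels = inputs.to(accelerator.device), labels.to(accelerator.device)
```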

def prepare_optimizer(self, optimizer: torch.optim.Optimizer):
def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement=None):
"""
Prepares a PyTorch Optimizer for training in any distributed setup. It is recommended to use
[`Accelerator.prepare`] instead.

Args:
optimizer (`torch.optim.Optimizer`):
A vanilla PyTorch optimizer to prepare
device_placement (`bool`, *optional*):
Whether or not to place the optimizer on the proper device. Will default to `self.device_placement`.
"""
optimizer = AcceleratedOptimizer(optimizer, device_placement=self.device_placement, scaler=self.scaler)
if device_placement is None:
device_placement = self.device_placement
optimizer = AcceleratedOptimizer(optimizer, device_placement=device_placement, scaler=self.scaler)
self._optimizers.append(optimizer)
return optimizer
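Similarly, a minimal sketch of `prepare_model` and `prepare_optimizer` with explicit placement choices (the objects are hypothetical; when the flag is omitted, both methods fall back to `self.device_placement` as documented above):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Move the model to accelerator.device, but keep the optimizer's state tensors where
# the user creates them (AcceleratedOptimizer receives device_placement=False).
model = accelerator.prepare_model(model, device_placement=True)
optimizer = accelerator.prepare_optimizer(optimizer, device_placement=False)
```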
