diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index cd74b68e24d..ab4b48660e7 100644
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -547,15 +547,15 @@ def print(self, *args, **kwargs):
         if self.is_local_main_process:
             print(*args, **kwargs)
 
-    def _prepare_one(self, obj, first_pass=False):
+    def _prepare_one(self, obj, first_pass=False, device_placement=None):
         # First pass of preparation: DataLoader, model, optimizer
         if first_pass:
             if isinstance(obj, torch.utils.data.DataLoader):
-                return self.prepare_data_loader(obj)
+                return self.prepare_data_loader(obj, device_placement=device_placement)
             elif isinstance(obj, torch.nn.Module):
-                return self.prepare_model(obj)
+                return self.prepare_model(obj, device_placement=device_placement)
             elif isinstance(obj, torch.optim.Optimizer):
-                optimizer = self.prepare_optimizer(obj)
+                optimizer = self.prepare_optimizer(obj, device_placement=device_placement)
                 return optimizer
         # Second pass of preparation: LR scheduler (which need the full list of optimizers)
         elif isinstance(obj, torch.optim.lr_scheduler._LRScheduler):
@@ -602,17 +602,33 @@ def _prepare_fsdp(self, *args):
         self._optimizers = optimizers
         return tuple(result)
 
-    def prepare(self, *args):
+    def prepare(self, *args, device_placement=None):
         """
         Prepare all objects passed in `args` for distributed training and mixed precision, then return them in the
         same order.
 
-        Accepts the following type of objects:
+        Args:
+            *args (list of objects):
+                Any of the following type of objects:
+
+                - `torch.utils.data.DataLoader`: PyTorch Dataloader
+                - `torch.nn.Module`: PyTorch Module
+                - `torch.optim.Optimizer`: PyTorch Optimizer
+                - `torch.optim.lr_scheduler._LRScheduler`: PyTorch LR Scheduler
 
-        - `torch.utils.data.DataLoader`: PyTorch Dataloader
-        - `torch.nn.Module`: PyTorch Module
-        - `torch.optim.Optimizer`: PyTorch Optimizer
+            device_placement (`List[bool]`, *optional*):
+                Used to customize whether automatic device placement should be performed for each object passed. Needs
+                to be a list of the same length as `args`.
         """
+        if device_placement is None:
+            device_placement = [None for _ in args]
+        elif self.distributed_type == DistributedType.DEEPSPEED:
+            raise ValueError("You can't customize device placements with DeepSpeed.")
+        elif len(device_placement) != len(args):
+            raise ValueError(
+                f"`device_placement` should be a list with {len(args)} elements (the number of objects passed)."
+            )
+
         if self.distributed_type == DistributedType.FSDP:
             model_count = 0
             optimizer_present = False
@@ -656,8 +672,10 @@ def prepare(self, *args):
         if self.distributed_type == DistributedType.DEEPSPEED:
             result = self._prepare_deepspeed(*args)
         else:
-            result = tuple(self._prepare_one(obj, first_pass=True) for obj in args)
-            result = tuple(self._prepare_one(obj) for obj in result)
+            result = tuple(
+                self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
+            )
+            result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement))
 
         if tpu_should_fix_optimizer:
             # 2. grabbing new model parameters
@@ -674,7 +692,7 @@ def prepare(self, *args):
 
         return result if len(result) > 1 else result[0]
 
-    def prepare_model(self, model: torch.nn.Module):
+    def prepare_model(self, model: torch.nn.Module, device_placement=None):
         """
         Prepares a PyTorch model for training in any distributed setup. It is recommended to use
         [`Accelerator.prepare`] instead.
@@ -682,9 +700,13 @@ def prepare_model(self, model: torch.nn.Module):
         Args:
             model (`torch.nn.Module`):
                 A PyTorch model to prepare
+            device_placement (`bool`, *optional*):
+                Whether or not to place the model on the proper device. Will default to `self.device_placement`.
         """
+        if device_placement is None:
+            device_placement = self.device_placement and self.distributed_type != DistributedType.FSDP
         self._models.append(model)
-        if self.device_placement and self.distributed_type != DistributedType.FSDP:
+        if device_placement:
             model = model.to(self.device)
         if self.distributed_type == DistributedType.MULTI_GPU:
             kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
@@ -894,7 +916,7 @@ def _prepare_deepspeed(self, *args):
             )
         return tuple(result)
 
-    def prepare_data_loader(self, data_loader: torch.utils.data.DataLoader):
+    def prepare_data_loader(self, data_loader: torch.utils.data.DataLoader, device_placement=None):
         """
         Prepares a PyTorch DataLoader for training in any distributed setup. It is recommended to use
         [`Accelerator.prepare`] instead.
@@ -902,19 +924,24 @@ def prepare_data_loader(self, data_loader: torch.utils.data.DataLoader):
         Args:
             data_loader (`torch.utils.data.DataLoader`):
                 A vanilla PyTorch DataLoader to prepare
+            device_placement (`bool`, *optional*):
+                Whether or not to place the batches on the proper device in the prepared dataloader. Will default to
+                `self.device_placement`.
         """
+        if device_placement is None:
+            device_placement = self.device_placement if self.distributed_type != DistributedType.TPU else False
         return prepare_data_loader(
             data_loader,
             self.device,
             num_processes=self.num_processes,
             process_index=self.process_index,
             split_batches=self.split_batches,
-            put_on_device=self.device_placement if self.distributed_type != DistributedType.TPU else False,
+            put_on_device=device_placement,
             rng_types=self.rng_types.copy(),
             dispatch_batches=self.dispatch_batches,
         )
 
-    def prepare_optimizer(self, optimizer: torch.optim.Optimizer):
+    def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement=None):
         """
         Prepares a PyTorch Optimizer for training in any distributed setup. It is recommended to use
         [`Accelerator.prepare`] instead.
@@ -922,8 +949,12 @@ def prepare_optimizer(self, optimizer: torch.optim.Optimizer):
         Args:
             optimizer (`torch.optim.Optimizer`):
                 A vanilla PyTorch optimizer to prepare
+            device_placement (`bool`, *optional*):
+                Whether or not to place the optimizer on the proper device. Will default to `self.device_placement`.
         """
-        optimizer = AcceleratedOptimizer(optimizer, device_placement=self.device_placement, scaler=self.scaler)
+        if device_placement is None:
+            device_placement = self.device_placement
+        optimizer = AcceleratedOptimizer(optimizer, device_placement=device_placement, scaler=self.scaler)
         self._optimizers.append(optimizer)
         return optimizer
 
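Usage sketch (not part of the diff): the snippet below shows how the new `device_placement` keyword is meant to be called, following the docstrings and checks added above. The toy model, optimizer, and dataloader are illustrative placeholders, not objects from this PR.

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()

    # Illustrative objects to prepare.
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    dataset = torch.utils.data.TensorDataset(torch.randn(32, 8), torch.randint(0, 2, (32,)))
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)

    # One entry per object passed to `prepare`: here we opt out of automatic
    # placement for the model (e.g. to move it manually later) but keep it for
    # the optimizer and dataloader. `None` entries fall back to the defaults.
    model, optimizer, dataloader = accelerator.prepare(
        model, optimizer, dataloader, device_placement=[False, True, True]
    )

Per the diff, the list must have exactly as many entries as objects passed, and passing `device_placement` at all raises a ValueError under DeepSpeed.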