From 91ff425bb0a21dad6fecb75480d0a05e15c4a837 Mon Sep 17 00:00:00 2001 From: Qingqing Cao Date: Tue, 26 Jul 2022 05:44:09 -0700 Subject: [PATCH] fix: saving model weights (#556) * fix: saving model weights checkpointing not saving model weights if calling `accelerator.prepare_model` instead of `accelerator.prepare` resolves issue: https://github.com/huggingface/accelerate/issues/555 * fix: saving model weights for optimizer and scheduler --- src/accelerate/accelerator.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 5f32a83d19a..b47a2dc1539 100644 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -501,16 +501,13 @@ def _prepare_one(self, obj, first_pass=False): if isinstance(obj, torch.utils.data.DataLoader): return self.prepare_data_loader(obj) elif isinstance(obj, torch.nn.Module): - self._models.append(obj) return self.prepare_model(obj) elif isinstance(obj, torch.optim.Optimizer): optimizer = self.prepare_optimizer(obj) - self._optimizers.append(optimizer) return optimizer # Second pass of preparation: LR scheduler (which need the full list of optimizers) elif isinstance(obj, torch.optim.lr_scheduler._LRScheduler): scheduler = self.prepare_scheduler(obj) - self._schedulers.append(scheduler) return scheduler # Return the unprocessed object if previous criteria was not met return obj @@ -625,6 +622,7 @@ def prepare(self, *args): return result if len(result) > 1 else result[0] def prepare_model(self, model): + self._models.append(model) if self.device_placement and self.distributed_type != DistributedType.FSDP: model = model.to(self.device) if self.distributed_type == DistributedType.MULTI_GPU: @@ -837,7 +835,9 @@ def prepare_data_loader(self, data_loader): ) def prepare_optimizer(self, optimizer): - return AcceleratedOptimizer(optimizer, device_placement=self.device_placement, scaler=self.scaler) + optimizer = 
AcceleratedOptimizer(optimizer, device_placement=self.device_placement, scaler=self.scaler) + self._optimizers.append(optimizer) + return optimizer def prepare_scheduler(self, scheduler): # We try to find the optimizer associated with `scheduler`, the default is the full list. @@ -846,13 +846,14 @@ def prepare_scheduler(self, scheduler): if getattr(scheduler, "optimizer", None) == opt.optimizer: optimizer = opt break - - return AcceleratedScheduler( + scheduler = AcceleratedScheduler( scheduler, optimizer, step_with_optimizer=self.step_scheduler_with_optimizer, split_batches=self.split_batches, ) + self._schedulers.append(scheduler) + return scheduler def backward(self, loss, **kwargs): """