fix: saving model weights (#556)
* fix: saving model weights

Checkpointing was not saving model weights when the model was prepared with `accelerator.prepare_model` instead of `accelerator.prepare`.
Resolves issue #555.

* fix: saving model weights for optimizer and scheduler
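
For reference, here is a minimal sketch of the failure scenario described above (the model, optimizer, and checkpoint directory are illustrative placeholders, not part of the commit): objects are prepared individually with `prepare_model` / `prepare_optimizer` rather than through `accelerator.prepare(...)`, and a checkpoint is then written with `save_state`.

```python
# Illustrative reproduction of issue #555 (the model, optimizer, and the
# "ckpt" directory are placeholders, not taken from the commit).
import torch
from accelerate import Accelerator

accelerator = Accelerator()

model = torch.nn.Linear(8, 2)  # stand-in for a real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Prepare objects one by one instead of via accelerator.prepare(...).
model = accelerator.prepare_model(model)
optimizer = accelerator.prepare_optimizer(optimizer)

# Before this commit, only prepare(...) registered objects in the internal
# tracking lists, so this checkpoint was written without model/optimizer
# state; after the fix, prepare_model/prepare_optimizer/prepare_scheduler
# register the objects themselves and save_state captures the weights.
accelerator.save_state("ckpt")
```

The commit moves the registration of models, optimizers, and schedulers out of `_prepare_one` and into `prepare_model`, `prepare_optimizer`, and `prepare_scheduler`, so both call paths populate the tracking lists used for checkpointing.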
csarron authored Jul 26, 2022
1 parent cc10071 commit 91ff425
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/accelerate/accelerator.py
@@ -501,16 +501,13 @@ def _prepare_one(self, obj, first_pass=False):
         if isinstance(obj, torch.utils.data.DataLoader):
             return self.prepare_data_loader(obj)
         elif isinstance(obj, torch.nn.Module):
-            self._models.append(obj)
             return self.prepare_model(obj)
         elif isinstance(obj, torch.optim.Optimizer):
             optimizer = self.prepare_optimizer(obj)
-            self._optimizers.append(optimizer)
             return optimizer
         # Second pass of preparation: LR scheduler (which need the full list of optimizers)
         elif isinstance(obj, torch.optim.lr_scheduler._LRScheduler):
             scheduler = self.prepare_scheduler(obj)
-            self._schedulers.append(scheduler)
             return scheduler
         # Return the unprocessed object if previous criteria was not met
         return obj
@@ -625,6 +622,7 @@ def prepare(self, *args):
         return result if len(result) > 1 else result[0]
 
     def prepare_model(self, model):
+        self._models.append(model)
         if self.device_placement and self.distributed_type != DistributedType.FSDP:
             model = model.to(self.device)
         if self.distributed_type == DistributedType.MULTI_GPU:
@@ -837,7 +835,9 @@ def prepare_data_loader(self, data_loader):
         )
 
     def prepare_optimizer(self, optimizer):
-        return AcceleratedOptimizer(optimizer, device_placement=self.device_placement, scaler=self.scaler)
+        optimizer = AcceleratedOptimizer(optimizer, device_placement=self.device_placement, scaler=self.scaler)
+        self._optimizers.append(optimizer)
+        return optimizer
 
     def prepare_scheduler(self, scheduler):
         # We try to find the optimizer associated with `scheduler`, the default is the full list.
@@ -846,13 +846,14 @@ def prepare_scheduler(self, scheduler):
             if getattr(scheduler, "optimizer", None) == opt.optimizer:
                 optimizer = opt
                 break
-
-        return AcceleratedScheduler(
+        scheduler = AcceleratedScheduler(
             scheduler,
             optimizer,
             step_with_optimizer=self.step_scheduler_with_optimizer,
             split_batches=self.split_batches,
         )
+        self._schedulers.append(scheduler)
+        return scheduler
 
     def backward(self, loss, **kwargs):
         """