Commit 2943e6a: resume step in param group

Author: tohtana
Committed: Mar 12, 2024
Parent: 26b9ea4
1 changed file: deepspeed/runtime/engine.py (3 additions, 4 deletions)
@@ -2785,7 +2785,7 @@ def load_checkpoint(self,
         if self.load_universal_checkpoint():
             self.optimizer.update_lp_params()
             if load_zero_checkpoint:
-                self.update_optimizer_step(step=client_states['iteration'] + 1)
+                self.update_optimizer_step(step=client_states['iteration'])
 
         return load_path, client_states
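The only change in this hunk is dropping the + 1, so the optimizer step is restored to the saved iteration value itself rather than one past it. A minimal illustration (the concrete iteration value is hypothetical):

    # Hypothetical checkpoint metadata; 'iteration' is the key read above.
    client_states = {'iteration': 1000}

    step_before = client_states['iteration'] + 1  # old behavior: restores 1001
    step_after = client_states['iteration']       # new behavior: restores 1000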

@@ -2966,7 +2966,7 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
     def update_optimizer_step(self, step):
 
         def set_step(d):
-            if isinstance(d['step'], torch.Tensor):
+            if 'step' in d and isinstance(d['step'], torch.Tensor):
                 d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
             else:
                 d['step'] = step
@@ -2975,8 +2975,7 @@ def set_step(d):
         base_optimizer = optimizer.optimizer
         state = base_optimizer.state
         for group in optimizer.param_groups:
-            if 'step' in group:
-                set_step(group)
+            set_step(group)
             for p in group['params']:
                 if p in state and len(state[p]) > 0 and 'step' in state[p]:
                     set_step(state[p])
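Taken together, the two hunks above make update_optimizer_step write the step into every param group, not only groups that already carried a 'step' key, with the guard moved into set_step so a missing key falls through to the plain-int branch. Below is a standalone sketch of the resulting logic; note it is applied to a bare torch.optim.Adam instead of the wrapped DeepSpeed optimizer, and the step value 1000 is made up for illustration:

    import torch

    def update_optimizer_step(base_optimizer, step):
        # Mirrors the post-commit logic from the diff above.
        def set_step(d):
            # Preserve tensor dtype/device when the existing step is a tensor;
            # otherwise (including when 'step' is absent) store a plain int.
            if 'step' in d and isinstance(d['step'], torch.Tensor):
                d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
            else:
                d['step'] = step

        state = base_optimizer.state
        for group in base_optimizer.param_groups:
            set_step(group)  # now unconditional: groups without 'step' gain one
            for p in group['params']:
                if p in state and len(state[p]) > 0 and 'step' in state[p]:
                    set_step(state[p])

    # Usage: restore step counters after loading a checkpoint.
    p = torch.nn.Parameter(torch.randn(4))
    opt = torch.optim.Adam([p])
    p.grad = torch.zeros_like(p)
    opt.step()  # populates per-parameter state; Adam keeps a 'step' entry there

    update_optimizer_step(opt, step=1000)
    print(opt.param_groups[0]['step'])  # 1000
    print(opt.state[p]['step'])         # 1000 (a tensor on recent torch versions)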