
Commit

Revert "Revert "Load optimizer state one at a time""
This reverts commit 9527ba3.
dirkgr committed Nov 8, 2023
1 parent 4dde969 commit efe2292
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions olmo/checkpoint.py
@@ -207,14 +207,7 @@ def load_fsdp_optim_state(fsdp_model: FSDP, optim: Optimizer, optim_state: Dict[
                 state[k] = v.to(device="cpu")
     torch.cuda.empty_cache()
     flattened_osd = fix_optim_state_dict(optim, flattened_osd)
-
-    for turn in range(get_local_world_size()):
-        log.info("Loading flattened optimizer state turn %d", turn)
-        if turn == get_local_rank():
-            optim.load_state_dict(flattened_osd)
-            del flattened_osd
-            gc.collect()
-        barrier()
+    optim.load_state_dict(flattened_osd)
 
 
 def save_state_dict(
@@ -629,12 +622,16 @@ def restore_checkpoint(
 
         # Load optimizer state.
         if load_optimizer_state:
-            log.info("Loading optimizer state...")
-            optim_state_dict_to_load = self._make_optim_state_dict_compatible(
-                load_state_dict(load_path, "optim.pt", local_cache=local_cache, map_location="cpu"),
-                og_keys_to_new,
-            )
-            load_fsdp_optim_state(fsdp_model, optim, optim_state_dict_to_load)
+            for turn in range(get_local_world_size()):
+                log.info("Loading optimizer state turn %d ...", turn)
+                if turn == get_local_rank():
+                    optim_state_dict_to_load = self._make_optim_state_dict_compatible(
+                        load_state_dict(load_path, "optim.pt", local_cache=local_cache, map_location="cpu"),
+                        og_keys_to_new,
+                    )
+                    load_fsdp_optim_state(fsdp_model, optim, optim_state_dict_to_load)
+                    gc.collect()
+                barrier()
 
         # Load other state.
         try:
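The net effect of the re-applied change: the turn-based loop moves up from `load_fsdp_optim_state` into `restore_checkpoint`, so each local rank takes its turn through the whole sequence of reading `optim.pt`, remapping its keys, and loading the result, rather than staggering only the final `optim.load_state_dict()` call. Since `optim.pt` holds a full copy of the optimizer state and `Optimizer.load_state_dict()` deep-copies its input, this bounds the per-node host-memory spike to roughly one extra copy of the state at a time instead of one per local rank. A minimal sketch of the pattern, using plain `torch.distributed` and torchrun's environment variables in place of the repo's `get_local_world_size()`, `get_local_rank()`, and `barrier()` helpers; `load_fn` is a hypothetical callable standing in for the checkpoint read:

import gc
import os

import torch.distributed as dist

def load_optim_state_one_rank_at_a_time(optim, load_fn):
    # LOCAL_RANK / LOCAL_WORLD_SIZE are set by torchrun.
    local_rank = int(os.environ["LOCAL_RANK"])
    local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
    for turn in range(local_world_size):
        if turn == local_rank:
            state_dict = load_fn()  # materialize the full optimizer state on CPU
            optim.load_state_dict(state_dict)  # deep-copies the state internally
            del state_dict
            gc.collect()  # free the original before the next rank's turn
        dist.barrier()  # all ranks wait for the current turn to finish

Because the loop runs over the local world size, nodes proceed in parallel; only the ranks within a node serialize, which is where the copies would otherwise pile up in shared host RAM.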
