
restoring optimizer states (with DeepSpeed plugin used) #242

Closed
csarron opened this issue Jan 30, 2022 · 5 comments

csarron (Contributor) commented Jan 30, 2022

Accelerate is a great library! Thanks for the amazing work!

I was able to save the optimizer/scheduler states using the Accelerate library, but when restoring them I got a CUDA out of memory error, so I guess the optimizer states are not saved properly. I can restore the states without error by using ckpt_states = torch.load(state_path, map_location='cpu'), but I'm not sure that is correct.

Could you provide some tips or suggestions? (I'm implementing a feature that fully restores training, but ran into this problem.) Thanks.

I guess that saving optimizer states with DeepSpeed works differently; I saw the HF Trainer does this, this, and this, but I'm not sure how to adapt that code into mine.
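
For reference, one DeepSpeed-native route is to let the DeepSpeed engine checkpoint itself, which also handles the (possibly partitioned) optimizer state. This is a minimal sketch, not the Trainer's actual code; ds_config, save_dir, step and epoch below are placeholders.

import deepspeed

# `model`, `optimizer`, `scheduler` and a DeepSpeed config dict `ds_config` are assumed to exist.
engine, optimizer, _, scheduler = deepspeed.initialize(
    model=model, optimizer=optimizer, lr_scheduler=scheduler, config=ds_config
)

# save_checkpoint writes model, optimizer and scheduler state under save_dir/<tag>/
engine.save_checkpoint(save_dir, tag=f"step-{step}", client_state={"epoch": epoch})

# load_checkpoint returns the path it loaded from plus the client_state dict saved above
load_path, client_state = engine.load_checkpoint(save_dir, tag=f"step-{step}")
epoch = client_state["epoch"]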

My checkpoint saving function is below:

import random
from pathlib import Path

import numpy as np
import torch


def save_ckpt(cfg, accelerator, model, optimizer, scheduler, epoch, step, score):
    accelerator.wait_for_everyone()
    ckpt_save_dir = Path(cfg.train.ckpt_save_dir)
    ckpt_file = ckpt_save_dir / "checkpoint.txt"
    ckpt_str = f"epoch-{epoch};step-{step};score-{score:.5f}"
    current_save_dir = ckpt_save_dir / ckpt_str
    if accelerator.is_local_main_process:
        current_save_dir.mkdir(parents=True, exist_ok=True)
        ckpt_file.write_text(ckpt_str)

    # save model
    model_path = current_save_dir / "model"
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        model_path,
        save_function=accelerator.save,
        state_dict=accelerator.get_state_dict(model),
    )

    # save optimizer, scheduler, scaler, epoch, step
    state_path = current_save_dir / "ckpt_states.pth"
    ckpt_states = {
        "scaler": accelerator.scaler.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epochs": epoch,
        "steps": step,
    }
    accelerator.save(ckpt_states, state_path)

    # save rng states
    rng_states = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "cpu": torch.random.get_rng_state(),
    }
    local_rank = accelerator.local_process_index
    if torch.cuda.is_available():
        if local_rank == -1:
            # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
            rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
        else:
            rng_states["cuda"] = torch.cuda.random.get_rng_state()

    if local_rank == -1:
        torch.save(rng_states, current_save_dir / "rng_state.pth")
    else:
        torch.save(rng_states, current_save_dir / f"rng_state_{local_rank}.pth")
    return current_save_dir
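
As an aside, a hypothetical helper (not part of the issue) that reads back the checkpoint.txt pointer written by save_ckpt to locate the latest checkpoint directory on resume could look like this:

from pathlib import Path

def find_latest_ckpt(ckpt_save_dir):
    """Return the newest checkpoint directory recorded in checkpoint.txt, or None."""
    ckpt_file = Path(ckpt_save_dir) / "checkpoint.txt"
    if not ckpt_file.exists():
        return None
    latest = Path(ckpt_save_dir) / ckpt_file.read_text().strip()
    return latest if latest.is_dir() else None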

My restore function looks like this:

    # configure optimizer
    optimizer = AdamW(model.parameters())  # actual hyperparameters omitted for brevity

    # config scheduler
    scheduler = get_scheduler(
        name=cfg.train.scheduler.name,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )

    if ckpt_states is not None:
        accelerator.scaler.load_state_dict(ckpt_states["scaler"])
        optimizer.load_state_dict(ckpt_states["optimizer"])  # <-- RuntimeError: CUDA out of memory happens here
        epoch_steps_trained = ckpt_states["steps"]
        epochs_trained = ckpt_states["epochs"]
        scheduler.load_state_dict(ckpt_states["scheduler"])

    if accelerator.is_local_main_process:
        logger.info(f"{num_training_steps=}, {num_warmup_steps=}, {epochs_trained=}, {epoch_steps_trained=}")

    model = accelerator.prepare_model(model)
    optimizer = accelerator.prepare_optimizer(optimizer)

    # other code that skips train_dataloader forward by epochs_trained and epoch_steps_trained
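
A minimal sketch of the CPU-mapped restore mentioned at the top of the issue (state_path is assumed to point at the ckpt_states.pth saved above); mapping the checkpoint to CPU keeps each process from allocating the saved CUDA tensors on load, which is one plausible cause of the out-of-memory error:

import torch

# Load the checkpoint onto CPU so no process allocates the saved CUDA tensors up front.
ckpt_states = torch.load(state_path, map_location="cpu")

# Optimizer.load_state_dict moves the loaded state tensors onto the device of the
# matching parameters, so the states end up back on the GPU afterwards.
optimizer.load_state_dict(ckpt_states["optimizer"])
scheduler.load_state_dict(ckpt_states["scheduler"])
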
csarron changed the title from "restoring optimizer states caused CUDA out of memory (with DeepSpeed plugin used)" to "restoring optimizer states (with DeepSpeed plugin used)" on Jan 30, 2022

sgugger (Collaborator) commented Jan 31, 2022

Hi there! We'll be working on adding a utility to help save/restore checkpoints in the coming month, so it should hopefully be easier to do this when it's there :-)

@muellerzr (Collaborator)

Closed with #255 🎉
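
(A minimal usage sketch, assuming the utility added in #255 corresponds to the Accelerator.save_state / Accelerator.load_state API in current Accelerate releases; the toy model and paths below are illustrative only.)

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataloader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# One call saves the model, optimizer and RNG states that Accelerate is tracking.
accelerator.save_state("checkpoints/step-1000")

# Later, after rebuilding the same objects and calling prepare() again:
accelerator.load_state("checkpoints/step-1000")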

@seanbenhur

I noticed the current save_state function won't save the epoch/step count. Is there any workaround to save it, @muellerzr?

@muellerzr (Collaborator)

@seanbenhur once #262 is solved, this will be saved indirectly through the scheduler. Otherwise it's up to the user to keep track of which epoch they're on, make a note of it, etc.
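
(In the meantime, a sketch of the manual workaround described above, keeping the epoch/step counter in a small JSON file next to the checkpoint; the file name and layout are arbitrary choices, not an Accelerate convention.)

import json
from pathlib import Path

def save_progress(ckpt_dir, epoch, step):
    # In a multi-process run, call this from the main process only.
    Path(ckpt_dir).mkdir(parents=True, exist_ok=True)
    (Path(ckpt_dir) / "progress.json").write_text(json.dumps({"epoch": epoch, "step": step}))

def load_progress(ckpt_dir):
    progress_file = Path(ckpt_dir) / "progress.json"
    if not progress_file.exists():
        return 0, 0
    progress = json.loads(progress_file.read_text())
    return progress["epoch"], progress["step"]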

@seanbenhur

Got it, thanks!
