feat(ckpt): add async upload and ckpt snapshot (opendilab#161)
* use fp16 in instruction (opendilab#80)

* delete torch_dtype of README's example code (opendilab#100)

* feat(ckpt): support async ckpt upload and ckpt snapshot

---------

Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com>
Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com>
Co-authored-by: wangguoteng.p <wangguoteng925@qq.com>
4 people authored Aug 8, 2023
1 parent ff0fa76 commit 29d27a6
Showing 5 changed files with 454 additions and 86 deletions.
29 changes: 18 additions & 11 deletions configs/7B_sft.py
@@ -7,22 +7,29 @@
 NUM_LAYER = 32
 VOCAB_SIZE = 103168
 
-MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
 # Ckpt folder format:
 # fs: 'local:/mnt/nfs/XXX'
 # oss: 'boto3:s3://model_weights/XXX'
+MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
 SAVE_CKPT_FOLDER = "local:llm_ckpts"
 LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
 
+# boto3 Ckpt folder format:
+# import os
+# BOTO3_IP = os.environ["BOTO3_IP"]  # boto3 bucket endpoint
+# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
+# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
+CHECKPOINT_EVERY = 50
 ckpt = dict(
-    # Path to save training ckpt.
-    save_ckpt_folder=SAVE_CKPT_FOLDER,
-    # Path to continue training ckpt (load model weights and scheduler/context states).
-    # load_ckpt_folder=LOAD_CKPT_FOLDER,
-    # Path to initialize with given model weights.
-    # load_model_only_folder=MODEL_ONLY_FOLDER,
-    checkpoint_every=50,
-    # Wheter to load optimizer states when continuing training.
-    load_optimizer=True,
+    enable_save_ckpt=False,  # enable ckpt save.
+    save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
+    # load_ckpt_folder=LOAD_CKPT_FOLDER,  # Ckpt path to resume training (load weights and scheduler/context states).
+    # load_model_only_folder=MODEL_ONLY_FOLDER,  # Path to initialize with given model weights.
+    load_optimizer=True,  # Whether to load optimizer states when continuing training.
+    checkpoint_every=CHECKPOINT_EVERY,
+    async_upload=True,  # async ckpt upload (only works for boto3 ckpt).
+    async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # path for temporary files during asynchronous upload.
+    snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]),  # snapshot ckpt storage path.
+    oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
 )
 
 TRAIN_FOLDER = "/path/to/dataset"
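For orientation, a minimal sketch of how these options combine for an async, boto3-backed setup; it only rearranges values already shown in the diff above, and the bucket name and BOTO3_IP endpoint are placeholders, not values shipped with this commit:

# Hedged config sketch; bucket name and BOTO3_IP are placeholders.
import os

BOTO3_IP = os.environ["BOTO3_IP"]  # boto3 bucket endpoint
SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
CHECKPOINT_EVERY = 50

ckpt = dict(
    enable_save_ckpt=True,  # nothing is saved unless this is True
    save_ckpt_folder=SAVE_CKPT_FOLDER,  # "boto3:" prefix is what makes async_upload effective
    load_optimizer=True,
    checkpoint_every=CHECKPOINT_EVERY,  # full ckpt every 50 steps
    async_upload=True,  # upload in the background instead of blocking training
    async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # local staging area in shared memory
    snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]),
    oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot every 25 steps
)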
33 changes: 29 additions & 4 deletions internlm/initialize/launch.py
@@ -11,6 +11,7 @@
 from internlm.core.context import Config
 from internlm.core.context import global_context as gpc
 from internlm.utils.logger import get_logger
+from internlm.utils.storage_manager import init_storage_manager
 
 logger = get_logger(__file__)
 
@@ -122,20 +123,44 @@ def args_sanity_check():
     if "load_model_only_folder" not in gpc.config.ckpt:
         gpc.config.ckpt._add_item("load_model_only_folder", None)
 
+    if "async_upload" not in gpc.config.ckpt:
+        gpc.config.ckpt._add_item("async_upload", False)
+    else:
+        if gpc.config.ckpt.async_upload:
+            assert "save_ckpt_folder" in gpc.config.ckpt
+            if "boto3:" not in gpc.config.ckpt.save_ckpt_folder:
+                if gpc.is_rank_for_log():
+                    logger.warning(
+                        "Storing ckpt on file system does not support asynchronous storage, will use sync save!"
+                    )
+                gpc.config.ckpt.async_upload = False
+            else:
+                if "async_upload_tmp_folder" not in gpc.config.ckpt:
+                    gpc.config.ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/")
+
+    if "snapshot_ckpt_folder" not in gpc.config.ckpt:
+        gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder, "snapshot"))
+
+    if "oss_snapshot_freq" not in gpc.config.ckpt and gpc.config.ckpt.checkpoint_every != float("inf"):
+        gpc.config.ckpt._add_item("oss_snapshot_freq", gpc.config.ckpt.checkpoint_every / 2)
+        assert gpc.config.ckpt.oss_snapshot_freq > 0
+
     assert not (
         gpc.config.ckpt.load_ckpt_folder is not None and gpc.config.ckpt.load_model_only_folder is not None
     ), "'load_ckpt_folder' and 'load_model_only_folder' cannot be set at the same time."
 
-    gpc.config.ckpt._add_item(
-        "enable_ckpt", gpc.config.ckpt.save_ckpt_folder is not None and gpc.config.ckpt.checkpoint_every > 0
-    )
+    if "enable_save_ckpt" not in gpc.config.ckpt:
+        gpc.config.ckpt._add_item("enable_save_ckpt", False)
 
     if gpc.is_rank_for_log():
         logger.info("+" * 15 + " Ckpt Info " + "+" * 15)  # pylint: disable=W1201
-        logger.info(f"is enable save ckpt: {gpc.config.ckpt.enable_ckpt}")
+        logger.info(f"is enable save ckpt: {gpc.config.ckpt.enable_save_ckpt}")
         logger.info(f"save_ckpt_folder: {gpc.config.ckpt.save_ckpt_folder}")
         logger.info(f"checkpoint_every: {gpc.config.ckpt.checkpoint_every}")
 
+    # initialize storage manager
+    init_storage_manager(gpc.config.ckpt)
+
     # tensorboard writer config
     if "enable_tb" not in gpc.config:
         gpc.config._add_item("enable_tb", True)
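To make the fallback above concrete, a sketch (not part of the diff) of how the sanity check rewrites a local-path config; the resulting values follow from the code shown and are stated under that assumption:

# Illustration of args_sanity_check defaults for a file-system save path.
ckpt = dict(save_ckpt_folder="local:llm_ckpts", async_upload=True, checkpoint_every=50)
# After args_sanity_check():
#   ckpt.async_upload         -> False  (file systems get sync save; rank 0 logs a warning)
#   ckpt.snapshot_ckpt_folder -> "local:llm_ckpts/snapshot"  (os.path.join default)
#   ckpt.oss_snapshot_freq    -> 25.0   (checkpoint_every / 2)
#   ckpt.enable_save_ckpt     -> False  (saving stays off unless explicitly enabled)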
89 changes: 88 additions & 1 deletion internlm/utils/model_checkpoint.py
@@ -4,6 +4,7 @@
 import copy
 import os
 import time
+from enum import Enum
 from typing import Dict
 
 import torch
@@ -15,10 +16,22 @@
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.utils.storage_manager import get_fns, llm_load, llm_save
+from internlm.utils.storage_manager import (
+    get_fns,
+    get_storage_manager,
+    llm_load,
+    llm_save,
+)
 
 logger = get_logger(__file__)
 
+quit_signal_handler = None
+
+
+class CheckpointType(Enum):
+    NORMAL_CHECKPOINT = 1
+    SNAPSHOT_CHECKPOINT = 2
+
 
 def get_model_topology(model):
     """
@@ -289,3 +302,77 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train

     if gpc.is_rank_for_log():
         logger.info(f"reload load_scheduler:{lr_scheduler}")
+
+
+class CheckpointSaveManager:
+    """StorageManagerContext"""
+
+    def __init__(
+        self,
+        ckpt_config,
+        model,
+        optimizer,
+        lr_scheduler,
+        model_config,
+    ) -> None:
+        """
+        CheckpointSaveManager is used to decide when to store ckpt. If it is an asynchronous
+        upload mode, you must call wait_async_upload_finish at the end of the program to wait
+        for the asynchronous ckpt upload to complete.
+        Args:
+            ckpt_config (dict): model checkpoint config.
+            model (nn.Module): model obj.
+            optimizer (object): optimizer obj.
+            lr_scheduler (object): lr_scheduler obj.
+            model_config (dict): model config.
+        """
+        self.enable_save_ckpt = ckpt_config.enable_save_ckpt
+        self.checkpoint_every = ckpt_config.checkpoint_every
+        self.save_ckpt_folder = ckpt_config.save_ckpt_folder
+        self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder
+        self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq
+        self.storage_manager = get_storage_manager()
+        self.snapshot_counter = 0
+
+        self.model = model
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+        self.model_config = model_config
+
+    def try_save_checkpoint(self, train_state):
+        if not self.enable_save_ckpt:
+            return
+
+        save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT
+        if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
+            save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT
+        if train_state.step_count % self.checkpoint_every == 0:
+            save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT
+        if save_ckpts is False:
+            if quit_signal_handler is not None:
+                save_ckpts, save_type = quit_signal_handler(train_state)
+
+        if save_ckpts:
+            # Wait for the previous round of asynchronous upload storage to complete.
+            self.storage_manager.wait()
+            if save_type == CheckpointType.SNAPSHOT_CHECKPOINT:
+                # Snapshot number, with only two snapshots written alternately.
+                self.snapshot_counter = (self.snapshot_counter + 1) % 2
+                save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}")
+            else:
+                save_ckpt_folder = self.save_ckpt_folder
+
+            save_checkpoint(
+                folder=save_ckpt_folder,
+                model=self.model,
+                optimizer=self.optimizer,
+                scheduler=self.lr_scheduler,
+                train_state=train_state,
+                model_config=self.model_config,
+            )
+
+    def wait_async_upload_finish(self):
+        """wait for all checkpoint uploads to be completed"""
+        self.storage_manager.wait()
+        torch.distributed.barrier()
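Since the docstring requires a final wait in async mode, a minimal usage sketch may help; it assumes a typical training loop, and names like train_loader, train_state, model, optimizer, and lr_scheduler stand in for the caller's own objects:

# Usage sketch with assumed caller-side names; not part of this commit.
ckpt_save_manager = CheckpointSaveManager(
    ckpt_config=gpc.config.ckpt,
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    model_config=gpc.config.model,
)

for batch in train_loader:
    # ... forward/backward/step, which advances train_state.step_count ...
    ckpt_save_manager.try_save_checkpoint(train_state)

# In async upload mode this final wait is required: it blocks until the last
# background upload finishes, so the process cannot exit with a half-written ckpt.
ckpt_save_manager.wait_async_upload_finish()

Alternating between two snapshot slots means a new snapshot only ever overwrites the older of the two, so one complete snapshot survives even if a save is interrupted.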