[Checkpoint compression] Support sharding stage1 v2 #9817

Merged · 1 commit · Jan 24, 2025
32 changes: 17 additions & 15 deletions paddlenlp/trainer/unified_checkpoint/load_local.py
@@ -149,14 +149,6 @@ def _remove_unused_keys(


def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False):
# Special process with split param.
if is_sharding_split_param_mode(args):
returned_optim_state_dict = load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint)
return returned_optim_state_dict

# init and get optimizer LR_Scheduler
returned_optim_state_dict = nested_copy(optimizer.state_dict())

if not safe_serialization:
index_filename, index_filename_master_weights = (
PADDLE_OPTIMIZER_INDEX_NAME,
@@ -165,6 +157,23 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
else:
index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME

with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f:
index = json.loads(f.read())

ckpt_quant_stage = "O0"
if "ckpt_quant_stage" in index:
ckpt_quant_stage = index["ckpt_quant_stage"]

# Special process with split param.
if is_sharding_split_param_mode(args):
returned_optim_state_dict = load_unified_optimizer_split_param(
args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage
)
return returned_optim_state_dict

# init and get optimizer LR_Scheduler
returned_optim_state_dict = nested_copy(optimizer.state_dict())

resolved_archive_file, sharded_metadata = get_optimizer_shard_files(
optimizer_path=resume_from_checkpoint,
index_filename=os.path.join(resume_from_checkpoint, index_filename),
@@ -184,13 +193,6 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
if len(resolved_archive_file) > 1:
resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")

with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f:
index = json.loads(f.read())

ckpt_quant_stage = "O0"
if "ckpt_quant_stage" in index:
ckpt_quant_stage = index["ckpt_quant_stage"]

# update has_master_weights and index_filename_master_weights
# 1. if the master weight exists, only has_master_weights is set True and loaded when needed
# 2. if master weight does not exist, convert model weight to master weight when needed
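The hunks above reorder load_unified_optimizer_locally: the optimizer index JSON is now read before the split-parameter branch, so the quantization stage recorded in the checkpoint ("O0" when the key is absent, i.e. no quantization) can be forwarded into load_unified_optimizer_split_param. A minimal, self-contained sketch of just that index-reading step; the helper name is illustrative and not part of the diff:

import json
import os

def read_ckpt_quant_stage(checkpoint_dir, index_filename):
    # Read the optimizer index once and pull out the quantization stage,
    # defaulting to "O0" (no checkpoint quantization) when the key is missing.
    with open(os.path.join(checkpoint_dir, index_filename), "r") as f:
        index = json.loads(f.read())
    return index["ckpt_quant_stage"] if "ckpt_quant_stage" in index else "O0"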
71 changes: 59 additions & 12 deletions paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py
@@ -36,18 +36,25 @@
get_expected_state_dict,
get_optimizer_shard_files,
mapping_optimizer_tp_actions,
update_master_weight_status,
)

__all__ = ["gather_splited_param_for_optimizer", "load_unified_optimizer_split_param"]


def merge_splited_param(
state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False
state_dict,
partial_tensor_list,
param_shape_info,
send_table,
recv_table,
is_master_weights=False,
ckpt_quant_stage="O0",
):
"""Merge the splited param in sharding group."""
global_rank = dist.get_rank()
for key in list(state_dict.keys()):
if state_dict[key].numel().item() == 1: # for example: beta1, beta2
if int(state_dict[key].numel()) == 1: # for example: beta1, beta2
continue

static_name = key if is_master_weights else generate_base_static_name(key)[0]
@@ -89,10 +96,21 @@ def merge_splited_param(
)
dist.stream.send(tensor, dst=recv_rank)
state_dict.pop(key)

if ckpt_quant_stage != "O0":
for key in list(state_dict.keys()):
if int(state_dict[key].numel()) == 1: # for example: beta1, beta2
static_name = key if is_master_weights else generate_base_static_name(key)[0]
if static_name in partial_tensor_list:
recv_rank = recv_table[static_name]
send_info = send_table[static_name]
if global_rank != recv_rank:
state_dict.pop(key)

return state_dict


def gather_splited_param_for_optimizer(optimizer):
def gather_splited_param_for_optimizer(optimizer, ckpt_quant_stage="O0"):
hcg = fleet.get_hybrid_communicate_group()
sharding_group = hcg.get_sharding_parallel_group()
global_rank = dist.get_rank()
@@ -127,7 +145,7 @@ def gather_splited_param_for_optimizer(optimizer):
for key in list(optim_state_dict.keys()):
static_name, _ = generate_base_static_name(key)
if static_name in param_slice_info.keys():
if optim_state_dict[key].numel().item() == 1: # for example: beta1, beta2
if int(optim_state_dict[key].numel()) == 1: # for example: beta1, beta2
continue
begin, end = param_slice_info[static_name]
shape, numel, _, _ = param_shape_info[static_name]
@@ -149,13 +167,15 @@ def gather_splited_param_for_optimizer(optimizer):
recv_table[key] = sharding_ranklist[0][0] # which sharding_rank to recv the splited tensor
send_table[key] = [(rank, begin, end) for rank, begin, end in sharding_ranklist]

merge_splited_param(optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False)
merge_splited_param(
optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False, ckpt_quant_stage
)
if master_weights is not None:
merge_splited_param(master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True)
return optim_state_dict, master_weights


def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint):
def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
returned_optim_state_dict = nested_copy(optimizer.state_dict())

index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME
@@ -208,6 +228,10 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
if len(resolved_archive_file) > 1:
resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")

has_master_weights, index_filename_master_weights = update_master_weight_status(
args, optimizer, has_master_weights, safe_serialization=True
)

if has_master_weights:
returned_optim_state_dict["master_weights"] = {}
resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files(
@@ -217,7 +241,9 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
if len(resolved_archive_file_mw) > 1:
resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards")

def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False):
def load_resolved_archive_file(
resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False, ckpt_quant_stage="O0"
):
returned_state_dict = {}

if model.config.tensor_parallel_degree > 1:
@@ -232,24 +258,38 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]):
continue
if model.config.tensor_parallel_degree > 1:
state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="cpu")
state_dict = load_state_dict(
shard_file,
tp_actions,
expected_keys,
device="cpu",
ckpt_quant_stage=ckpt_quant_stage,
)
else:
state_dict = load_state_dict(shard_file, None, expected_keys, device="cpu")
state_dict = load_state_dict(
shard_file,
None,
expected_keys,
device="cpu",
ckpt_quant_stage=ckpt_quant_stage,
)
returned_state_dict.update(state_dict)
del state_dict
gc.collect()

return returned_state_dict

# get tp params
state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim)
state_dict_optim = load_resolved_archive_file(
resolved_archive_file, sharded_metadata, expected_keys_optim, ckpt_quant_stage=ckpt_quant_stage
)

# need to split param for different sharding rank, maybe need to deal with oom issue.
for key in list(state_dict_optim.keys()):
key_name = key.split("/")
static_name = struct2static_name_mappings.get(key_name[0], None)

if state_dict_optim[key].numel().item() > 1:
if int(state_dict_optim[key].numel()) > 1:
begin, end = param_slice_info[static_name]
shape, numel, index, padded_size = param_shape_info[static_name]
state_dict_optim[key] = state_dict_optim[key].reshape([-1])
@@ -284,7 +324,7 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected

for key in list(state_dict_master_weight.keys()):
static_name = struct2static_name_mappings.get(key, None)
if state_dict_master_weight[key].numel().item() > 1:
if int(state_dict_master_weight[key].numel()) > 1:
begin, end = param_slice_info[static_name]
shape, numel, index, padded_size = param_shape_info[static_name]
state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1])
@@ -303,6 +343,13 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
paddle.framework._current_expected_place(), False
)
returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)

# master weight cast (only in remove_master_weight)
if returned_optim_state_dict["master_weights"][static_name].dtype != paddle.float32:
returned_optim_state_dict["master_weights"][static_name] = paddle.cast(
returned_optim_state_dict["master_weights"][static_name], dtype=paddle.float32
)

returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])

return returned_optim_state_dict
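In sharding_split_param_utils.py the quantization stage is threaded through gather_splited_param_for_optimizer and merge_splited_param. The extra loop added to merge_splited_param only runs when ckpt_quant_stage != "O0": scalar optimizer states (numel == 1, e.g. the beta1/beta2 power accumulators) that belong to split parameters are popped on every rank except the one designated to receive the merged tensor, so the saved state dict keeps a single copy. The load path additionally casts each recovered master weight to float32 if it is not already (the diff comments this as the remove_master_weight case). The sketch below mimics the cleanup with the static-name mapping simplified away (keys are matched directly against partial_tensor_list); names are illustrative:

import paddle

def drop_scalar_states_on_non_recv_ranks(state_dict, partial_tensor_list, recv_table, global_rank):
    # Scalar accumulators of split parameters are kept only on the receiving rank;
    # every other rank removes them from its shard of the optimizer state.
    for key in list(state_dict.keys()):
        if int(state_dict[key].numel()) == 1 and key in partial_tensor_list:
            if global_rank != recv_table[key]:
                state_dict.pop(key)
    return state_dict

# Single-process illustration: rank 1 is not the receiver, so the scalar is dropped.
sd = {"w_moment1": paddle.zeros([4]), "w_beta1_pow_acc": paddle.to_tensor([0.9])}
drop_scalar_states_on_non_recv_ranks(sd, ["w_beta1_pow_acc"], {"w_beta1_pow_acc": 0}, global_rank=1)
assert "w_beta1_pow_acc" not in sd and "w_moment1" in sd

The recurring switch from state_dict[key].numel().item() to int(state_dict[key].numel()) is behavior-preserving for ordinary tensors and also keeps working if numel() ever returns a plain Python int.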
4 changes: 3 additions & 1 deletion paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
@@ -344,7 +344,9 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
return

if is_sharding_split_param_mode(self.args):
optim_state_dict, master_weights = gather_splited_param_for_optimizer(optimizer)
optim_state_dict, master_weights = gather_splited_param_for_optimizer(
optimizer, self.args.ckpt_quant_stage if "quant_reach_limit" not in infohub else "O0"
)
else:
optim_state_dict = nested_copy(optimizer.state_dict())
master_weights = None
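Finally, on the save side, save_unified_optimizer now forwards the configured quantization stage into gather_splited_param_for_optimizer, falling back to "O0" once "quant_reach_limit" is present in infohub. A tiny sketch of that guard as a standalone function; the helper name is ours and infohub is modelled as a plain dict for illustration:

def effective_quant_stage(configured_stage, infohub):
    # Disable checkpoint quantization for this save once the quant limit has been reached.
    return configured_stage if "quant_reach_limit" not in infohub else "O0"

assert effective_quant_stage("O1", {}) == "O1"
assert effective_quant_stage("O1", {"quant_reach_limit": True}) == "O0"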