fixed scaling weights
Signed-off-by: Piotr Kaminski <pikaminski@nvidia.com>
Piotr Kaminski committed Aug 14, 2024
1 parent 5087268 commit 61d0f47
Showing 4 changed files with 104 additions and 74 deletions.
9 changes: 6 additions & 3 deletions nemo/export/trt_llm/converter/model_converter.py
@@ -131,8 +131,11 @@ def model_to_trtllm_ckpt(
vocab_size_padded = pad_vocab_size(vocab_size, tensor_parallel_size) if has_lm_head else vocab_size

if has_lm_head and vocab_size_padded != vocab_size:
pad_width = vocab_size_padded - vocab_size
lm_head_weight = np.pad(lm_head_weight, ((0, pad_width), (0, 0)), "constant", constant_values=0)
padding = (0, 0, 0, vocab_size_padded - vocab_size)
embedding_key = "transformer.vocab_embedding.weight"
lm_head_weight = torch.nn.functional.pad(lm_head_weight, padding, "constant", 0)
weights_dict[embedding_key] = torch.nn.functional.pad(weights_dict[embedding_key], padding, "constant", 0)


world_size = tensor_parallel_size * pipeline_parallel_size
hidden_act = nemo_model_config.get('activation')
@@ -161,7 +164,7 @@ def model_to_trtllm_ckpt(
'share_embedding_table': use_embedding_sharing,
'quantization': {
'quant_algo': "FP8" if nemo_model_config.get('fp8', False) else None,
'kv_cache_quant_algo': None,
'kv_cache_quant_algo': None, # TODO maybe "FP8",
},
'bias': nemo_model_config.get('bias'),
'apply_query_key_layer_scaling': False,
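A note on the padding change above: torch.nn.functional.pad reads its padding tuple from the last dimension backwards, so (0, 0, 0, n) leaves the hidden dimension untouched and appends n zero rows to a 2-D (vocab, hidden) weight, matching what the old np.pad call did for lm_head_weight and now applied to the vocab embedding as well. A minimal sketch with made-up sizes:

import torch
import torch.nn.functional as F

# Illustrative sizes only: pad a (vocab_size, hidden) weight up to vocab_size_padded rows.
vocab_size, hidden, vocab_size_padded = 50254, 4096, 50256
weight = torch.randn(vocab_size, hidden)

# For a 2-D tensor the tuple is (last-dim left, last-dim right, first-dim front, first-dim back),
# so (0, 0, 0, n) appends n rows of zeros.
padding = (0, 0, 0, vocab_size_padded - vocab_size)
padded = F.pad(weight, padding, "constant", 0)

assert padded.shape == (vocab_size_padded, hidden)
assert padded[vocab_size:].abs().sum() == 0  # the appended rows are all zeros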
73 changes: 43 additions & 30 deletions nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py
@@ -94,6 +94,33 @@ def rename_key_dist_ckpt(old_key: str, layer: int):
return rename_key(new_key)


def load_scaling_factors(model, num_layers, tp_rank, out_dir, split_factor, storage_type, export_config):
starmap_args = []
for key, val in model.items():
if 'extra_state' not in key:
continue

for i in range(num_layers):
starmap_args.append(
(
tp_rank,
out_dir,
split_factor,
rename_key_dist_ckpt(key, i),
[val[i]],
storage_type,
None,
export_config,
{},
)
)

for starmap_arg in starmap_args:
scaling_factors = split_and_save_weight(*starmap_arg)

return scaling_factors


@torch.no_grad()
def convert_model_to_trt_llm_ckpt(
nemo_model_config,
@@ -186,41 +213,24 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int):
handle_model_level_weights(model, 0, 0)
model = extract_layers_with_prefix(model, transformer_layer_prefix)

scaling_factors = load_scaling_factors(model, num_layers, tp_rank, out_dir, split_factor, storage_type, export_config)

starmap_args = []
for key, val in model.items():
if 'extra_state' in key:
continue

# Let's rename/map the key to the old layer name previously.
# Since the state dict value has the full layers, let's select the ith layer weights/biases here.
if len(val.size()) == 1:
key_vals = [(rename_key_dist_ckpt(key, 0), val)]
else:
key_vals = [(rename_key_dist_ckpt(key, i), val[i]) for i in range(num_layers)]

for (k, v) in key_vals:
starmap_args.append(
(
tp_rank,
out_dir,
split_factor,
# Let's rename/map the key to the old layer name previously. You can try printing out
# the rename_key output of the old llama checkpoint and compare.
rename_key_dist_ckpt(key, 0),
# Since the state dict value has the full layers, let's select the ith layer weights/biases here.
[val],
storage_type,
None,
export_config,
)
(tp_rank, out_dir, split_factor, k, [v], storage_type, None, export_config, scaling_factors)
)
else:
for i in range(num_layers):
starmap_args.append(
(
tp_rank,
out_dir,
split_factor,
# Let's rename/map the key to the old layer name previously. You can try printing out
# the rename_key output of the old llama checkpoint and compare.
rename_key_dist_ckpt(key, i),
# Since the state dict value has the full layers, let's select the ith layer weights/biases here.
[val[i]],
storage_type,
None,
export_config,
)
)

starmap_args = tqdm(starmap_args, desc="saving weights")

@@ -239,6 +249,9 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int):
model_level_weights[key] = torch.concatenate(values, axis=0)
weights_dict[key] = model_level_weights[key]

for key, value in scaling_factors.items():
weights_dict[key] = value

return weights_dict


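For orientation, here is a condensed sketch of the conversion flow this file now follows: extra_state entries are turned into scaling factors up front, every regular weight is converted with those factors available, and the factors are merged into the returned weights_dict. The function signatures below are simplified stand-ins, not the actual NeMo APIs:

# Condensed sketch of the new flow; load_scaling_factors and convert_weight are placeholders.
def convert_rank(model, num_layers, load_scaling_factors, convert_weight):
    # 1) extra_state entries become a dict of TRT-LLM scaling-factor names -> tensors.
    scaling_factors = load_scaling_factors(model, num_layers)

    weights_dict = {}
    for key, val in model.items():
        if 'extra_state' in key:
            # already handled above; the main loop only sees real weights
            continue
        # 2) each weight is converted with the scaling factors passed along,
        #    so FP8 tensors can be rescaled before casting.
        weights_dict.update(convert_weight(key, val, scaling_factors))

    # 3) the scaling factors themselves are part of the exported checkpoint.
    weights_dict.update(scaling_factors)
    return weights_dict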
62 changes: 36 additions & 26 deletions nemo/export/trt_llm/converter/utils.py
@@ -177,12 +177,12 @@ def write_int8(vals, dir, base_key, split_dim, tp_rank, split_factor, kv_cache_o
def get_suffix(key):
return '.' + key.split('.')[-1]

def get_layer_prefix(key):
def get_trt_llm_prefix(key):
layer_num = key.split(".")[1]
return f'transformer.layers.{layer_num}'

def get_new_keyname(key):
layer_prefix = get_layer_prefix(key)
layer_prefix = get_trt_llm_prefix(key)

if ("post_attention_layernorm.weight" in key
or "post_attention_layernorm.bias" in key
@@ -239,51 +239,62 @@ def get_new_keyname(key):
def is_scaling_factor(key):
return "scale_fwd" in key


def get_scaling_factor_keys(key):
base_key = '.'.join(key.split('.')[:-2]) + '.weight'
base_key = '.'.join(get_new_keyname(base_key).split('.')[:-1])
weight_key = '.'.join(key.split('.')[:-2]) + '.weight'
base_key = '.'.join(get_new_keyname(weight_key).split('.')[:-1])
weight_scale = base_key + '.weights_scaling_factor'
activation_scale = base_key + '.activation_scaling_factor'
return weight_scale, activation_scale


first = True
def handle_scaling_factor(key, val, dir, split_gated_activation):
weights_key, activation_key = get_scaling_factor_keys(key)

activation_factor = 1 / val[0].view(1)
weights_factor = 1 / val[1].view(1)
# weights_factor_2 = 1 / val[2].view(1)
weights_factor_2 = 1 / val[2].view(1)

save_val(torch_to_numpy(activation_factor), dir, activation_key)
save_val(torch_to_numpy(weights_factor), dir, weights_key)
#save_val(torch_to_numpy(weights_factor_2), dir, weights_key + '_2')
# save_val(torch_to_numpy(weights_factor_2), dir, weights_key + '_2')

# global first
# if first:
# first = False
# for i in range(32):
# save_val(torch_to_numpy(weights_factor_2), dir, f'transformer.layers.{i}.attention.kv_cache_scaling_factor')

if split_gated_activation and (("mlp.dense_h_to_4h" in key) or ("mlp.linear_fc1" in key)):
layer_prefix = get_layer_prefix(key)
layer_prefix = get_trt_llm_prefix(key)
mapped_key = f'{layer_prefix}.mlp.gate'
save_val(torch_to_numpy(activation_factor), dir, mapped_key + '.activation_scaling_factor')
save_val(torch_to_numpy(weights_factor), dir, mapped_key + '.weights_scaling_factor')
#save_val(torch_to_numpy(weights_factor_2), dir, mapped_key + '.weights_scaling_factor_2')
# save_val(torch_to_numpy(weights_factor_2), dir, mapped_key + '.weights_scaling_factor_2')

global weights_dict
return weights_dict


def cast_val_datatype(vals, key, storage_type, is_fp8_model):
def cast_val_datatype(vals, trt_llm_key, storage_type, is_fp8_model, scaling_factors):
if is_fp8_model:
fp8_storage_type = torch.float8_e4m3fn
quantized_keys = ['attention.dense', 'attention.linear', 'attention.query_key_value', 'attention.linear_qkv', 'mlp.linear', 'mlp.dense']
quantized_keys = [ k.split('.weights_scaling_factor')[0] for k in scaling_factors.keys() if '.weights_scaling_factor' in k]
for k in quantized_keys:
if k in key:
if k in trt_llm_key:
storage_type = fp8_storage_type
s = scaling_factors[k + '.weights_scaling_factor']
vals = [val.to(torch.float32) / s for val in vals]
break

return [val.to(storage_type) for val in vals]


# Note: in multi_query_mode, only query heads are split between multiple GPUs, while key/value head
# are not split as there is only one head per key/value.
@torch.no_grad()
def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_type, act_range, config):
def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_type, act_range, config, sf):
use_attention_nemo_shape = config.get("use_attention_nemo_shape", False)
split_gated_activation = config.get("split_gated_activation", False)
num_attention_heads = config.get("num_attention_heads", 0)
@@ -299,7 +310,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t
return handle_scaling_factor(key, vals[0], saved_dir, split_gated_activation)

save_int8 = int8_outputs == "all" or int8_outputs == "kv_cache_only"
layer_prefix = get_layer_prefix(key)
layer_prefix = get_trt_llm_prefix(key)

if not isinstance(vals, list):
vals = [vals]
@@ -309,14 +320,13 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t
if "layernorm.weight" in key and config.get("apply_layernorm_1p", False):
vals = [val.float() + 1.0 for val in vals]

vals = cast_val_datatype(vals, key, storage_type, is_fp8_model)
trt_llm_key = get_new_keyname(key)
vals = cast_val_datatype(vals, trt_llm_key, storage_type, is_fp8_model, sf)
if convert_on_device:
assert len(vals) == 1 # Should only convert a single device param per call
assert torch.is_tensor(vals[0])
elif torch.is_tensor(vals[0]):
vals = [torch_to_numpy(val.cpu()) for val in vals]

trt_llm_key = get_new_keyname(key)
if (
"input_layernorm.weight" in key
or "input_layernorm.bias" in key
@@ -353,7 +363,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t
if act_range is not None and int8_outputs == "all":
base_key = trt_llm_key.replace(".weight", "")
vals_i8 = generate_int8(val, act_range, multi_query_mode=multi_query_mode)
write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor)
write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor) # is cat dim always defined?

elif (
"mlp.dense_h_to_4h.weight" in key
@@ -462,13 +472,12 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t
qkv = np.split(val, [q_num, q_num + 1], axis=2)

query_groups_shape = qkv[0].shape
if len(query_groups_shape) > 1:
if (query_groups_shape[1] % split_factor) != 0:
raise Exception(
"Number of query groups of the models is {0}. Please select tensor parallelism size "
"that can split the number of query groups to equal number of query matrices in the "
"each GPU.".format(query_groups_shape[1])
)
if len(query_groups_shape) > 1 and ((query_groups_shape[1] % split_factor) != 0):
raise Exception(
"Number of query groups of the models is {0}. Please select tensor parallelism size "
"that can split the number of query groups to equal number of query matrices in the "
"each GPU.".format(query_groups_shape[1])
)

q_split = np.split(qkv[0], split_factor, axis=1)
k_split = np.split(qkv[1], split_factor, axis=1)
@@ -538,10 +547,11 @@ def split(v, tp_size, idx, dim=0):
"""Splits the np tensor v on dim and return the idx's slice."""
if tp_size == 1:
return v

if len(v.shape) == 1:
return np.ascontiguousarray(np.split(v, tp_size)[idx])
else:
return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx])

return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx])


def init_model_parallel_from_nemo(reshard_model):
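The main behavioral change in this file is in cast_val_datatype: a weight whose TRT-LLM key has a matching '.weights_scaling_factor' entry is divided by that factor and stored as torch.float8_e4m3fn instead of the default storage type. Below is a small self-contained sketch of a per-tensor FP8 round trip; the amax-based scale is a generic convention chosen for illustration, not necessarily the exact factor NeMo stores, and it requires a PyTorch build that supports float8_e4m3fn:

import torch

# Toy tensor and a generic per-tensor scale; values are made up for illustration.
weight = torch.randn(16, 16)
scale = weight.abs().amax() / torch.finfo(torch.float8_e4m3fn).max

fp8_weight = (weight.to(torch.float32) / scale).to(torch.float8_e4m3fn)  # what gets saved
dequantized = fp8_weight.to(torch.float32) * scale                       # what a consumer reconstructs

print((weight - dequantized).abs().max())  # small quantization error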
34 changes: 19 additions & 15 deletions nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
@@ -77,8 +77,7 @@ def unpack_extra_state_key(key):
size = int(key.split('/')[1].split('_')[-1])
return basename, size

def clear_key_basename_from_state_dict(state_dict, basename):
# '/' is important, as scaling factors are saved to basename.scaling_fwd
def clear_loaded_extra_states(state_dict, basename):
to_remove = [k for k in state_dict.keys() if basename + '/' in k]
for key in to_remove:
state_dict.pop(key)
@@ -105,7 +104,7 @@ def standarize_distributed_scaling_factors(state_dict):
scaling_factors = load_scaling_factors(state_dict, basename, size)
if scaling_factors != []:
state_dict[basename + '.scale_fwd'] = scaling_factors
state_dict = clear_key_basename_from_state_dict(state_dict, basename)
state_dict = clear_loaded_extra_states(state_dict, basename)

return state_dict

@@ -138,35 +137,42 @@ def load_sharded_metadata_torch_dist(checkpoint_dir: Union[Path, TarPath], torch

def load_sharded_pickle_extra_state_scale(dir):
scales = []
layer_number = 0

i = 0
while pt_file_list := list(dir.glob(f'shard_{i}_*.pt')):
while pt_file_list := list(dir.glob(f'shard_{layer_number}_*.pt')):
pt_file = pt_file_list[0]
checkpoint = torch.load(pt_file)
checkpoint.seek(0)
state_dict = torch.load(checkpoint)
if not 'scale_fwd' in state_dict:
if 'scale_fwd' not in state_dict:
return []
scale = state_dict['scale_fwd'].cpu()
scales.append(scale)
i += 1
layer_number += 1

all_scales = torch.stack(scales)
return all_scales

def contains_extra_states(subdir):
return list(subdir.glob('shard_0_*.pt')) != []

def load_extra_state_from_pickle(sharded_state_dict, subdir):
if scales := load_sharded_pickle_extra_state_scale(subdir):
key = subdir.name + '.scale_fwd'
sharded_state_dict[key] = scales

return sharded_state_dict

def load_sharded_metadata_zarr(checkpoint_dir: Union[Path, TarPath], torch_tensor=True):
sharded_state_dict = {}
for subdir in checkpoint_dir.iterdir():
if not subdir.is_dir():
continue

key = subdir.name
if list(subdir.glob('shard_0_*.pt')):
scales = load_sharded_pickle_extra_state_scale(subdir)
if scales != []:
key = key + '.scale_fwd'
sharded_state_dict[key] = scales
if contains_extra_states(subdir):
sharded_state_dict = load_extra_state_from_pickle(sharded_state_dict, subdir)
elif (subdir / '.zarray').exists():
key = subdir.name
zstore = ZarrPathStore(subdir)
arr = zarr.open(zstore, 'r')

@@ -179,8 +185,6 @@ def load_sharded_metadata_zarr(checkpoint_dir: Union[Path, TarPath], torch_tenso
sharded_state_dict[key] = torch.from_numpy(arr[:]).view(str_dtype_to_torch(arr.dtype.name))
else:
sharded_state_dict[key] = arr[:]
else:
continue

return sharded_state_dict

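The zarr-loader refactor above isolates the per-layer extra_state handling in contains_extra_states and load_extra_state_from_pickle. The underlying pattern is to walk shard_<layer>_*.pt files in order, unpickle the buffer each file wraps, pull out 'scale_fwd', and stack the per-layer scales. A sketch of that loop, assuming (as the double torch.load in the diff suggests) that each .pt file holds a serialized byte buffer containing the extra_state dict:

from pathlib import Path
import torch

def load_layer_scales(subdir: Path):
    # Sketch of the loop in load_sharded_pickle_extra_state_scale, not the actual implementation.
    scales, layer = [], 0
    while (shards := list(subdir.glob(f"shard_{layer}_*.pt"))):
        buffer = torch.load(shards[0])    # outer load yields the serialized buffer
        buffer.seek(0)
        extra_state = torch.load(buffer)  # inner load yields the extra_state dict
        if "scale_fwd" not in extra_state:
            return []                     # this module carries no scaling factors
        scales.append(extra_state["scale_fwd"].cpu())
        layer += 1
    return torch.stack(scales) if scales else []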
