Patch summary: build the tensor-parallel key suffix without mutating `key`.

save_val: the suffixed name is computed once as `tp_key` and used for the
membership test, the pinned-buffer allocation, the copy and the numpy store.
save_expert_split: the old `key += ...` inside the loop appended one suffix per
iteration to the same string; the suffix is now derived per iteration, after
`tp_num` has been assigned, and `key` is left untouched.

diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py
index f1882bbea3a1..d7abe38c936a 100755
--- a/nemo/export/trt_llm/converter/utils.py
+++ b/nemo/export/trt_llm/converter/utils.py
@@ -63,8 +63,8 @@
 
 
 def save_val(val, dir, key, tp_num=None):
-    if tp_num:
-        key += f".{tp_num}.bin"
+    suffix = f".{tp_num}.bin" if tp_num else ''
+    tp_key = key + suffix
     global weights_dict
 
     # Transpose linear layer weights to the correct shape.
@@ -74,14 +74,14 @@ def save_val(val, dir, key, tp_num=None):
             val = val.reshape(val.shape[0], -1)
             val = torch.transpose(val, 0, 1)
-        if key not in weights_dict:
-            weights_dict[key] = torch.empty(
+        if tp_key not in weights_dict:
+            weights_dict[tp_key] = torch.empty(
                 val.size(), dtype=val.dtype, layout=val.layout, device="cpu", pin_memory=True
             )
-        weights_dict[key].copy_(val, non_blocking=True)
+        weights_dict[tp_key].copy_(val, non_blocking=True)
     else:
         if len(val.shape) >= 2:
             val = np.ascontiguousarray(np.transpose(val.reshape(val.shape[0], -1), [1, 0]))
-        weights_dict[key] = val
+        weights_dict[tp_key] = val
 
 
 def save_split(split_vals, dir, key, i, split_factor):
@@ -91,11 +91,10 @@ def save_split(split_vals, dir, key, i, split_factor):
 
 def save_expert_split(split_vals, dir, key, i, split_factor):
     for j, val in enumerate(split_vals):
         tp_num = i * split_factor + j
-        if tp_num:
-            key += f".{tp_num}.bin"
+        suffix = f".{tp_num}.bin" if tp_num else ''
         global weights_dict
-        weights_dict[key] = val
+        weights_dict[key + suffix] = val
 
 
 def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False):