Skip to content

Commit

Permalink
bugfix: naming
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Kaminski <pikaminski@nvidia.com>
  • Loading branch information
Piotr Kaminski committed Aug 21, 2024
1 parent 73d9261 commit 84a5e5e
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions nemo/export/trt_llm/converter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@


def save_val(val, dir, key, tp_num=None):
if tp_num:
key += f".{tp_num}.bin"
suffix = f".{tp_num}.bin" if tp_num else ''
tp_key = key + suffix

global weights_dict
# Transpose linear layer weights to the correct shape.
Expand All @@ -74,14 +74,14 @@ def save_val(val, dir, key, tp_num=None):
val = val.reshape(val.shape[0], -1)
val = torch.transpose(val, 0, 1)
if key not in weights_dict:
weights_dict[key] = torch.empty(
weights_dict[tp_key] = torch.empty(
val.size(), dtype=val.dtype, layout=val.layout, device="cpu", pin_memory=True
)
weights_dict[key].copy_(val, non_blocking=True)
weights_dict[tp_key].copy_(val, non_blocking=True)
else:
if len(val.shape) >= 2:
val = np.ascontiguousarray(np.transpose(val.reshape(val.shape[0], -1), [1, 0]))
weights_dict[key] = val
weights_dict[tp_key] = val


def save_split(split_vals, dir, key, i, split_factor):
Expand All @@ -91,11 +91,10 @@ def save_split(split_vals, dir, key, i, split_factor):

def save_expert_split(split_vals, dir, key, i, split_factor):
for j, val in enumerate(split_vals):
suffix = f".{tp_num}.bin" if tp_num else ''
tp_num = i * split_factor + j
if tp_num:
key += f".{tp_num}.bin"
global weights_dict
weights_dict[key] = val
weights_dict[key + suffix] = val


def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False):
Expand Down

0 comments on commit 84a5e5e

Please sign in to comment.