diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index d7abe38c936a..471086d2a333 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -91,8 +91,9 @@ def save_split(split_vals, dir, key, i, split_factor): def save_expert_split(split_vals, dir, key, i, split_factor): for j, val in enumerate(split_vals): - suffix = f".{tp_num}.bin" if tp_num else '' tp_num = i * split_factor + j + suffix = f".{tp_num}.bin" if tp_num else '' + global weights_dict weights_dict[key + suffix] = val