Skip to content

Commit

Permalink
[Infer] Delete generate_rank_mapping when export multi cards model (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanlehome authored Dec 3, 2024
1 parent c4d79f4 commit f7da5a6
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 24 deletions.
1 change: 0 additions & 1 deletion llm/predict/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def main():
predictor.model.generation_config.save_pretrained(export_args.output_path)

predictor.tokenizer.save_pretrained(export_args.output_path)
llm_utils.generate_rank_mapping(os.path.join(export_args.output_path, "rank_mapping.csv"))

if tensor_parallel_degree > 1:
export_args.output_path = os.path.join(export_args.output_path, f"rank_{tensor_parallel_rank}")
Expand Down
23 changes: 0 additions & 23 deletions paddlenlp/trl/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,29 +267,6 @@ def get_infer_model_path(input_dir, model_prefix):
return os.path.join(input_dir, model_prefix)


def generate_rank_mapping(output_filename):
ring_id = -1
try:
hcg = fleet.get_hybrid_communicate_group()
model_parallel_group = hcg.get_model_parallel_group()
ring_id = model_parallel_group.id
except Exception:
pass

if ring_id == -1:
return

world_size = dist.get_world_size()
with open(output_filename, "w") as f:
f.write("[ring_id -> ranks]\n")
f.write(",".join(map(str, [0] + list(range(world_size)))) + "\n")
f.write(",".join(map(str, [ring_id] + list(range(world_size)))) + "\n")

f.write("[rank -> ring_ids]\n")
for i in range(world_size):
f.write("{},0,{}\n".format(i, ring_id))


def deserialize_from_file(fp):
x_type = fp.read(1)
x_type_out = struct.unpack("c", x_type)[0]
Expand Down

0 comments on commit f7da5a6

Please sign in to comment.