From a948bfe9b96a062ec55c70a0fe5e9a7f1ac8151e Mon Sep 17 00:00:00 2001
From: ckl117
Date: Mon, 10 Feb 2025 10:48:31 +0800
Subject: [PATCH] support dp

---
 csrc/gpu/get_output.cc                               |  9 ++---
 csrc/gpu/save_with_output_msg.cc                     |  9 ++---
 llm/predict/predictor.py                             | 35 ++++++++++++++++---
 .../transformers/fused_transformer_layers.py         | 12 +++++--
 .../transformers/generation_utils.py                 |  5 ++-
 .../transformers/qwen2_moe/modeling.py               |  1 +
 paddlenlp/trl/llm_utils.py                           |  3 +-
 7 files changed, 58 insertions(+), 16 deletions(-)

diff --git a/csrc/gpu/get_output.cc b/csrc/gpu/get_output.cc
index 292bd2b1482c..42389a45254f 100644
--- a/csrc/gpu/get_output.cc
+++ b/csrc/gpu/get_output.cc
@@ -27,13 +27,14 @@ struct msgdata {
 };
 
 void GetOutput(const paddle::Tensor& x,
-               int64_t rank_id,
+               int64_t tp_rank,
+               int64_t dp_rank,
                bool wait_flag) {
-  if (rank_id > 0) return;
+  if (tp_rank > 0) return;
 
   static struct msgdata msg_rcv;
 
-  static key_t key = ftok("./", 1);
+  static key_t key = ftok("./", dp_rank+1);
 
   static int msgid = msgget(key, IPC_CREAT | 0666);
 
@@ -62,7 +63,7 @@ void GetOutput(const paddle::Tensor& x,
 
 PD_BUILD_OP(get_output)
     .Inputs({"x"})
-    .Attrs({"rank_id: int64_t",
+    .Attrs({"tp_rank: int64_t", "dp_rank: int64_t",
             "wait_flag: bool"})
     .Outputs({"x_out"})
     .SetInplaceMap({{"x", "x_out"}})
diff --git a/csrc/gpu/save_with_output_msg.cc b/csrc/gpu/save_with_output_msg.cc
index 0c75d4408979..42034e7ac4ef 100644
--- a/csrc/gpu/save_with_output_msg.cc
+++ b/csrc/gpu/save_with_output_msg.cc
@@ -28,15 +28,16 @@ struct msgdata {
 
 void SaveOutMmsg(const paddle::Tensor& x,
                  const paddle::Tensor& not_need_stop,
-                 int64_t rank_id) {
-  if (rank_id > 0) return;
+                 int64_t tp_rank,
+                 int64_t dp_rank) {
+  if (tp_rank > 0) return;
   auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
   int64_t *x_data = x_cpu.data<int64_t>();
   auto not_need_stop_cpu = not_need_stop.copy_to(paddle::CPUPlace(), false);
   bool* not_need_stop_data = not_need_stop_cpu.data<bool>();
 
   static struct msgdata msg_sed;
-  static key_t key = ftok("./", 1);
+  static key_t key = ftok("./", dp_rank+1);
   static int msgid = msgget(key, IPC_CREAT | 0666);
 
   msg_sed.mtype = 1;
@@ -54,7 +55,7 @@ void SaveOutMmsg(const paddle::Tensor& x,
 
 PD_BUILD_OP(save_output)
     .Inputs({"x", "not_need_stop"})
-    .Attrs({"rank_id: int64_t"})
+    .Attrs({"tp_rank: int64_t", "dp_rank: int64_t"})
     .Outputs({"x_out"})
     .SetInplaceMap({{"x", "x_out"}})
     .SetKernelFn(PD_KERNEL(SaveOutMmsg));
\ No newline at end of file
diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py
index 77017c954f4d..5c6a43d0c90a 100644
--- a/llm/predict/predictor.py
+++ b/llm/predict/predictor.py
@@ -60,6 +60,7 @@ class PredictorArgument:
     model_name_or_path: str = field(default=None, metadata={"help": "The directory of model."})
     model_prefix: str = field(default="model", metadata={"help": "the prefix name of static model"})
+    dp_degree: int = field(default=8, metadata={"help": "The data parallel degree."})
     src_length: int = field(default=1024, metadata={"help": "The max length of source text."})
     min_length: int = field(default=1, metadata={"help": "the min length for decoding."})
     max_length: int = field(default=1024, metadata={"help": "the max length for decoding."})
@@ -1428,11 +1429,14 @@ def predict():
     paddle.set_device(predictor_args.device)
     paddle.set_default_dtype(predictor_args.dtype)
 
-    tensor_parallel_degree = paddle.distributed.get_world_size()
+    world_size = paddle.distributed.get_world_size()
+    dp_degree = predictor_args.dp_degree
+    tensor_parallel_degree = world_size // dp_degree
+
     if tensor_parallel_degree > 1:
         strategy = fleet.DistributedStrategy()
         strategy.hybrid_configs = {
-            "dp_degree": 1,
+            "dp_degree": dp_degree,
             "mp_degree": tensor_parallel_degree,
             "pp_degree": 1,
             "sharding_degree": 1,
@@ -1460,14 +1464,32 @@ def predict():
                 target_texts.append("")
     else:
-        source_texts = ["解释一下温故而知新"] * predictor_args.batch_size
-        target_texts = [""] * predictor_args.batch_size
+        source_texts = ["解释一下温故而知新",
+                        "你好,你是谁",
+                        "请问中国的首都是哪里呢?",
+                        "请问法国的首都是哪里呢?",
+                        "请问日本的首都是哪里呢?",
+                        "请问南非的首都是哪里呢?",
+                        "请问英国的首都是哪里呢?",
+                        "小日本为什么叫小日本呢?",
+                        ]
+        target_texts = [""] * len(source_texts)
 
     batch_source_texts = batchfy_text(source_texts, predictor_args.batch_size)
     batch_target_texts = batchfy_text(target_texts, predictor_args.batch_size)
 
+    if predictor_args.dp_degree > 1:
+        hcg = fleet.get_hybrid_communicate_group()
+        dp_degree = hcg.get_data_parallel_world_size()
+        dp_id = hcg.get_data_parallel_rank()
+    else:
+        dp_degree = 1
+        dp_id = 0
+
     with open(model_args.output_file, "w", encoding="utf-8") as f:
         for bs, batch_source_text in enumerate(batch_source_texts):
+            if bs % dp_degree != dp_id:
+                continue
             logger.info("Start predict")
             outputs = predictor.predict(batch_source_text)
             logger.info("End predict")
@@ -1487,6 +1509,11 @@ def predict():
     if predictor_args.benchmark:
         benchmark(predictor, predictor_args, model_args)
+    paddle.distributed.barrier()
+    import paddle.distributed as dist
+    data = paddle.to_tensor([1])
+    dist.all_reduce(data)
+    print(data)
 
 
 def benchmark(predictor, predictor_args, model_args):
     # Just construct a simple benchmark input. We pad input to the src_length.
diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py
index f499c19cb34d..9576e7270369 100644
--- a/paddlenlp/experimental/transformers/fused_transformer_layers.py
+++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from __future__ import annotations
+from paddle.distributed import fleet
 
 import os
 from dataclasses import dataclass
@@ -377,6 +378,13 @@ def __init__(self, config: FusedMultiTransformerConfig):
         self._epsilon = config.epsilon
         self._residual_alpha = config.residual_alpha
         self.nranks = config.nranks
+
+        if self.nranks > 1:
+            dp_degree = fleet.get_hybrid_communicate_group().get_data_parallel_world_size()
+            self.tp_group = None
+            if dp_degree > 1:
+                self.tp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
+
         self.norm_type = config.norm_type
         if self.norm_type == "layernorm":
             self.norm_func = fused_layer_norm
@@ -1434,7 +1442,7 @@ def forward(
             )
             # all_reduce
             if self.nranks > 1:
-                dist.all_reduce(out_linear_out)
+                dist.all_reduce(out_linear_out, group = self.tp_group)
 
             # ffn layernorm
             tmp_out, residual_input = self.compute_ffn_layernorm(out_linear_out, residual_input, i)
@@ -1457,7 +1465,7 @@ def forward(
 
             # all_reduce
             if self.nranks > 1:
-                dist.all_reduce(ffn2_out)
+                dist.all_reduce(ffn2_out, group = self.tp_group)
 
             # norm + residual_add_bias
             tmp_out, residual_input = self.compute_bias_residual_layernorm(
diff --git a/paddlenlp/experimental/transformers/generation_utils.py b/paddlenlp/experimental/transformers/generation_utils.py
index f0a8d1a035c2..5fdec9cdb4e3 100644
--- a/paddlenlp/experimental/transformers/generation_utils.py
+++ b/paddlenlp/experimental/transformers/generation_utils.py
@@ -18,6 +18,7 @@
 
 import paddle
 import paddle.nn.functional as F
+from paddle.distributed import fleet
 
 from paddlenlp.generation import GenerationMixin, LogitsProcessor, LogitsProcessorList
@@ -733,7 +734,8 @@ def _post_process_(
         _, next_tokens = paddle.tensor.top_p_sampling(probs, top_p)
 
         if self.config.tensor_parallel_degree > 1:
-            paddle.distributed.broadcast(next_tokens, 0)
+            rank = fleet.get_hybrid_communicate_group().get_data_parallel_rank() * self.config.tensor_parallel_degree
+            paddle.distributed.broadcast(next_tokens, rank, group = fleet.get_hybrid_communicate_group().get_model_parallel_group())
 
         with paddle.base.framework._stride_in_no_check_dy2st_diff():
             from paddlenlp_ops import update_inputs_v2
@@ -760,6 +762,7 @@ def _post_process_(
                 next_tokens,
                 model_kwargs["not_need_stop"],
                 self.config.tensor_parallel_rank,
+                fleet.get_hybrid_communicate_group().get_data_parallel_rank()
             )
 
         return next_tokens
diff --git a/paddlenlp/experimental/transformers/qwen2_moe/modeling.py b/paddlenlp/experimental/transformers/qwen2_moe/modeling.py
index 9b1600fafd58..e88acc3a4c07 100644
--- a/paddlenlp/experimental/transformers/qwen2_moe/modeling.py
+++ b/paddlenlp/experimental/transformers/qwen2_moe/modeling.py
@@ -109,6 +109,7 @@ def __init__(self, config: Qwen2MoeConfig):
                 self.vocab_size,
                 self.hidden_size,
                 weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
+                mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group(),
             )
         else:
             self.embed_tokens = nn.Embedding(
diff --git a/paddlenlp/trl/llm_utils.py b/paddlenlp/trl/llm_utils.py
index f3971ac5271d..d20091d7bdf2 100644
--- a/paddlenlp/trl/llm_utils.py
+++ b/paddlenlp/trl/llm_utils.py
@@ -610,7 +610,7 @@ def read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Q
     from paddlenlp_ops import get_output
 
     while True:
-        get_output(output_tensor, 0, True)
+        get_output(output_tensor, 0, fleet.get_hybrid_communicate_group().get_data_parallel_rank(), True)
         if int(output_tensor[0, 0]) == -2:  # read none
             continue
         bsz = int(output_tensor[1, 0])
@@ -740,6 +740,7 @@ def init_dist_env():
 
     hcg = fleet.get_hybrid_communicate_group()
     tensor_parallel_rank = hcg.get_model_parallel_rank()
+    tensor_parallel_degree = hcg.get_model_parallel_world_size()
 
     return tensor_parallel_rank, tensor_parallel_degree
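
Usage sketch (not part of the diff above): the snippet mirrors the rank bookkeeping this patch introduces: dp_degree splits the world size into tensor-parallel groups, each data-parallel replica takes every dp_degree-th batch, and collectives are scoped to the model-parallel group. The script name and the DP_DEGREE constant are illustrative placeholders; only Paddle fleet APIs already used in the patch appear here.

# Minimal sketch of the dp/tp rank bookkeeping added in this patch, assuming a
# hypothetical standalone script launched with e.g.
#   python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" dp_sketch.py
import paddle.distributed as dist
from paddle.distributed import fleet

DP_DEGREE = 2  # placeholder; the patch exposes this as PredictorArgument.dp_degree

world_size = dist.get_world_size()
tensor_parallel_degree = world_size // DP_DEGREE  # same derivation as in predict()

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": DP_DEGREE,
    "mp_degree": tensor_parallel_degree,
    "pp_degree": 1,
    "sharding_degree": 1,
}
fleet.init(is_collective=True, strategy=strategy)

hcg = fleet.get_hybrid_communicate_group()
dp_id = hcg.get_data_parallel_rank()       # which data-parallel replica this process belongs to
tp_rank = hcg.get_model_parallel_rank()    # only tp_rank 0 of each group writes the msg queue
tp_group = hcg.get_model_parallel_group()  # scopes all_reduce/broadcast, as in the patch

batches = ["batch_0", "batch_1", "batch_2", "batch_3"]
for bs, batch in enumerate(batches):
    if bs % DP_DEGREE != dp_id:  # round-robin sharding, same test as the patched predict() loop
        continue
    print(f"dp_id={dp_id} tp_rank={tp_rank} handles {batch}")

Each tensor-parallel group then runs its own copy of the predictor; on the C++ side, get_output and save_output key their System V message queue with ftok("./", dp_rank+1), so the replicas read and write separate queues.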