Commit

support dp
ckl117 committed Feb 10, 2025
1 parent 2f78730 commit a948bfe
Showing 7 changed files with 58 additions and 16 deletions.
csrc/gpu/get_output.cc (9 changes: 5 additions & 4 deletions)
@@ -27,13 +27,14 @@ struct msgdata {
 };
 
 void GetOutput(const paddle::Tensor& x,
-               int64_t rank_id,
+               int64_t tp_rank,
+               int64_t dp_rank,
                bool wait_flag) {
-  if (rank_id > 0) return;
+  if (tp_rank > 0) return;
 
   static struct msgdata msg_rcv;
 
-  static key_t key = ftok("./", 1);
+  static key_t key = ftok("./", dp_rank+1);
 
   static int msgid = msgget(key, IPC_CREAT | 0666);
 
@@ -62,7 +63,7 @@ void GetOutput(const paddle::Tensor& x,
 
 PD_BUILD_OP(get_output)
     .Inputs({"x"})
-    .Attrs({"rank_id: int64_t",
+    .Attrs({"tp_rank: int64_t", "dp_rank: int64_t",
             "wait_flag: bool"})
     .Outputs({"x_out"})
     .SetInplaceMap({{"x", "x_out"}})
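
Note on usage: ftok("./", dp_rank+1) derives a distinct System V message-queue key per data-parallel group, so tp_rank 0 of each DP group reads only its own tokens. A minimal, illustrative reader-side sketch in Python (poll_tokens is a made-up helper name; the op signature and the -2 "nothing read" sentinel follow this commit, and paddlenlp_ops is assumed to be built from this csrc):

import paddle
from paddlenlp_ops import get_output

def poll_tokens(output_tensor: paddle.Tensor, dp_rank: int, wait_flag: bool = True):
    # tp_rank is passed as 0 because the op returns immediately for tp_rank > 0.
    get_output(output_tensor, 0, dp_rank, wait_flag)
    if int(output_tensor[0, 0]) == -2:  # read none
        return None
    return output_tensor
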
csrc/gpu/save_with_output_msg.cc (9 changes: 5 additions & 4 deletions)
@@ -28,15 +28,16 @@ struct msgdata {
 
 void SaveOutMmsg(const paddle::Tensor& x,
                  const paddle::Tensor& not_need_stop,
-                 int64_t rank_id) {
-  if (rank_id > 0) return;
+                 int64_t tp_rank,
+                 int64_t dp_rank) {
+  if (tp_rank > 0) return;
   auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
   int64_t *x_data = x_cpu.data<int64_t>();
   auto not_need_stop_cpu = not_need_stop.copy_to(paddle::CPUPlace(), false);
   bool* not_need_stop_data = not_need_stop_cpu.data<bool>();
 
   static struct msgdata msg_sed;
-  static key_t key = ftok("./", 1);
+  static key_t key = ftok("./", dp_rank+1);
   static int msgid = msgget(key, IPC_CREAT | 0666);
 
   msg_sed.mtype = 1;
@@ -54,7 +55,7 @@ void SaveOutMmsg(const paddle::Tensor& x,
 
 PD_BUILD_OP(save_output)
     .Inputs({"x", "not_need_stop"})
-    .Attrs({"rank_id: int64_t"})
+    .Attrs({"tp_rank: int64_t", "dp_rank: int64_t"})
     .Outputs({"x_out"})
     .SetInplaceMap({{"x", "x_out"}})
     .SetKernelFn(PD_KERNEL(SaveOutMmsg));
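
Note: save_output is the producer side of the same queue; it now takes tp_rank and dp_rank, returns early for tp_rank > 0, and writes to the dp_rank-keyed queue that get_output reads. A hedged caller-side sketch (push_tokens is an illustrative name; an initialized hybrid fleet group is assumed):

import paddle
from paddle.distributed import fleet
from paddlenlp_ops import save_output

def push_tokens(next_tokens: paddle.Tensor, not_need_stop: paddle.Tensor):
    hcg = fleet.get_hybrid_communicate_group()
    # Only tp_rank 0 of each DP group actually sends; other ranks no-op inside the op.
    save_output(next_tokens, not_need_stop,
                hcg.get_model_parallel_rank(),
                hcg.get_data_parallel_rank())
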
llm/predict/predictor.py (35 changes: 31 additions & 4 deletions)
@@ -60,6 +60,7 @@
 class PredictorArgument:
     model_name_or_path: str = field(default=None, metadata={"help": "The directory of model."})
     model_prefix: str = field(default="model", metadata={"help": "the prefix name of static model"})
+    dp_degree: int = field(default=8, metadata={"help": "The data parallel degree."})
     src_length: int = field(default=1024, metadata={"help": "The max length of source text."})
     min_length: int = field(default=1, metadata={"help": "the min length for decoding."})
     max_length: int = field(default=1024, metadata={"help": "the max length for decoding."})
@@ -1428,11 +1429,14 @@ def predict():
     paddle.set_device(predictor_args.device)
     paddle.set_default_dtype(predictor_args.dtype)
 
-    tensor_parallel_degree = paddle.distributed.get_world_size()
+    world_size = paddle.distributed.get_world_size()
+    dp_degree = predictor_args.dp_degree
+    tensor_parallel_degree = world_size // dp_degree
+
     if tensor_parallel_degree > 1:
         strategy = fleet.DistributedStrategy()
         strategy.hybrid_configs = {
-            "dp_degree": 1,
+            "dp_degree": dp_degree,
             "mp_degree": tensor_parallel_degree,
             "pp_degree": 1,
             "sharding_degree": 1,
@@ -1460,14 +1464,32 @@ def predict():
             target_texts.append("")
 
     else:
-        source_texts = ["解释一下温故而知新"] * predictor_args.batch_size
-        target_texts = [""] * predictor_args.batch_size
+        source_texts = ["解释一下温故而知新",
+                        "你好,你是谁",
+                        "请问中国的首都是哪里呢?",
+                        "请问法国的首都是哪里呢?",
+                        "请问日本的首都是哪里呢?",
+                        "请问南非的首都是哪里呢?",
+                        "请问英国的首都是哪里呢?",
+                        "小日本为什么叫小日本呢?",
+                        ]
+        target_texts = [""] * len(source_texts)
 
     batch_source_texts = batchfy_text(source_texts, predictor_args.batch_size)
     batch_target_texts = batchfy_text(target_texts, predictor_args.batch_size)
 
+    if predictor_args.dp_degree > 1:
+        hcg = fleet.get_hybrid_communicate_group()
+        dp_degree = hcg.get_data_parallel_world_size()
+        dp_id = hcg.get_data_parallel_rank()
+    else:
+        dp_degree = 1
+        dp_id = 0
+
     with open(model_args.output_file, "w", encoding="utf-8") as f:
         for bs, batch_source_text in enumerate(batch_source_texts):
+            if bs % dp_degree != dp_id:
+                continue
             logger.info("Start predict")
             outputs = predictor.predict(batch_source_text)
             logger.info("End predict")
@@ -1487,6 +1509,11 @@ def predict():
     if predictor_args.benchmark:
         benchmark(predictor, predictor_args, model_args)
 
+    paddle.distributed.barrier()
+    import paddle.distributed as dist
+    data = paddle.to_tensor([1])
+    dist.all_reduce(data)
+    print(data)
 
 def benchmark(predictor, predictor_args, model_args):
     # Just construct a simple benchmark input. We pad input to the src_length.
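
Note: two changes here do the heavy lifting: tensor_parallel_degree is now derived as world_size // dp_degree, and each DP rank processes every dp_degree-th batch (bs % dp_degree == dp_id). A self-contained sketch of that round-robin split with illustrative values (8 prompts, batch_size=1, dp_degree=2):

# Illustrative only: how batches end up distributed across data-parallel ranks.
source_texts = ["prompt_%d" % i for i in range(8)]
batch_size = 1
dp_degree = 2

batches = [source_texts[i:i + batch_size] for i in range(0, len(source_texts), batch_size)]
for dp_id in range(dp_degree):
    mine = [batch for bs, batch in enumerate(batches) if bs % dp_degree == dp_id]
    print("dp rank", dp_id, "handles", mine)
# dp rank 0 takes batches 0, 2, 4, 6; dp rank 1 takes batches 1, 3, 5, 7.
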
paddlenlp/experimental/transformers/fused_transformer_layers.py (12 changes: 10 additions & 2 deletions)
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from __future__ import annotations
+from paddle.distributed import fleet
 
 import os
 from dataclasses import dataclass
@@ -377,6 +378,13 @@ def __init__(self, config: FusedMultiTransformerConfig):
         self._epsilon = config.epsilon
         self._residual_alpha = config.residual_alpha
         self.nranks = config.nranks
+
+        if self.nranks > 1:
+            dp_degree = fleet.get_hybrid_communicate_group().get_data_parallel_world_size()
+            self.tp_group = None
+            if dp_degree > 1:
+                self.tp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
+
         self.norm_type = config.norm_type
         if self.norm_type == "layernorm":
             self.norm_func = fused_layer_norm
@@ -1434,7 +1442,7 @@ def forward(
             )
             # all_reduce
             if self.nranks > 1:
-                dist.all_reduce(out_linear_out)
+                dist.all_reduce(out_linear_out, group = self.tp_group)
 
             # ffn layernorm
             tmp_out, residual_input = self.compute_ffn_layernorm(out_linear_out, residual_input, i)
@@ -1457,7 +1465,7 @@
 
             # all_reduce
             if self.nranks > 1:
-                dist.all_reduce(ffn2_out)
+                dist.all_reduce(ffn2_out, group = self.tp_group)
 
             # norm + residual_add_bias
             tmp_out, residual_input = self.compute_bias_residual_layernorm(
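
Note: with dp_degree > 1 the row-parallel partial sums must be reduced only across the model-parallel group, not the whole world, hence the explicit group argument (group=None keeps the old world-wide behavior when there is no DP dimension). A hedged sketch of the same pattern, assuming fleet has been initialized with hybrid configs:

import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

def reduce_over_tp(tensor: paddle.Tensor, nranks: int) -> paddle.Tensor:
    # Restrict the all_reduce to the MP group whenever a DP dimension exists;
    # otherwise fall back to the default (global) group.
    tp_group = None
    if nranks > 1:
        hcg = fleet.get_hybrid_communicate_group()
        if hcg.get_data_parallel_world_size() > 1:
            tp_group = hcg.get_model_parallel_group()
        dist.all_reduce(tensor, group=tp_group)
    return tensor
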
paddlenlp/experimental/transformers/generation_utils.py (5 changes: 4 additions & 1 deletion)
@@ -18,6 +18,7 @@
 
 import paddle
 import paddle.nn.functional as F
+from paddle.distributed import fleet
 
 from paddlenlp.generation import GenerationMixin, LogitsProcessor, LogitsProcessorList
 
@@ -733,7 +734,8 @@ def _post_process_(
         _, next_tokens = paddle.tensor.top_p_sampling(probs, top_p)
 
         if self.config.tensor_parallel_degree > 1:
-            paddle.distributed.broadcast(next_tokens, 0)
+            rank = fleet.get_hybrid_communicate_group().get_data_parallel_rank() * self.config.tensor_parallel_degree
+            paddle.distributed.broadcast(next_tokens, rank, group = fleet.get_hybrid_communicate_group().get_model_parallel_group())
 
         with paddle.base.framework._stride_in_no_check_dy2st_diff():
             from paddlenlp_ops import update_inputs_v2
@@ -760,6 +762,7 @@ def _post_process_(
                 next_tokens,
                 model_kwargs["not_need_stop"],
                 self.config.tensor_parallel_rank,
+                fleet.get_hybrid_communicate_group().get_data_parallel_rank()
             )
             return next_tokens
 
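
Note: the broadcast source must be the global rank of the tp_rank-0 process inside the caller's own DP group. With tensor-parallel ranks laid out innermost, that is dp_rank * tensor_parallel_degree; for example, dp_rank 1 with tensor_parallel_degree 4 broadcasts from global rank 4. A sketch under that layout assumption:

import paddle
from paddle.distributed import fleet

def broadcast_sampled_tokens(next_tokens: paddle.Tensor, tensor_parallel_degree: int):
    # Assumes the hybrid topology places tensor-parallel ranks innermost, so the
    # first TP rank of DP group d has global rank d * tensor_parallel_degree.
    hcg = fleet.get_hybrid_communicate_group()
    src = hcg.get_data_parallel_rank() * tensor_parallel_degree
    paddle.distributed.broadcast(next_tokens, src, group=hcg.get_model_parallel_group())
    return next_tokens
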
paddlenlp/experimental/transformers/qwen2_moe/modeling.py (1 change: 1 addition & 0 deletions)
@@ -109,6 +109,7 @@ def __init__(self, config: Qwen2MoeConfig):
                 self.vocab_size,
                 self.hidden_size,
                 weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
+                mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group(),
             )
         else:
             self.embed_tokens = nn.Embedding(
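
Note: the vocab-parallel embedding is now told explicitly which communicator to shard over, so it uses the MP group rather than a default group when DP is enabled. A hedged sketch; the class is assumed to be fleet's VocabParallelEmbedding (the surrounding file is not shown here) and the sizes are placeholders:

import paddle
import paddle.nn as nn
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import VocabParallelEmbedding  # assumed class

vocab_size, hidden_size = 32000, 4096  # placeholder sizes
hcg = fleet.get_hybrid_communicate_group()
embed_tokens = VocabParallelEmbedding(
    vocab_size,
    hidden_size,
    weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
    mp_group=hcg.get_model_parallel_group(),  # shard the vocab over the MP group only
)
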
paddlenlp/trl/llm_utils.py (3 changes: 2 additions & 1 deletion)
@@ -610,7 +610,7 @@ def read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Q
     from paddlenlp_ops import get_output
 
     while True:
-        get_output(output_tensor, 0, True)
+        get_output(output_tensor, 0, fleet.get_hybrid_communicate_group().get_data_parallel_rank(), True)
         if int(output_tensor[0, 0]) == -2: # read none
             continue
         bsz = int(output_tensor[1, 0])
@@ -740,6 +740,7 @@ def init_dist_env():
     hcg = fleet.get_hybrid_communicate_group()
 
     tensor_parallel_rank = hcg.get_model_parallel_rank()
+    tensor_parallel_degree = hcg.get_model_parallel_world_size()
     return tensor_parallel_rank, tensor_parallel_degree
 
 
