From 9414d67335de6839a2cfc85d1324111c618091a5 Mon Sep 17 00:00:00 2001 From: drownfish19 Date: Mon, 30 Sep 2024 09:44:16 +0000 Subject: [PATCH 01/20] add expert parallel utils --- paddlenlp/transformers/moe_gate.py | 554 ++++++++++++++++++++++++++++ paddlenlp/transformers/moe_layer.py | 276 ++++++++++++++ 2 files changed, 830 insertions(+) create mode 100644 paddlenlp/transformers/moe_gate.py create mode 100644 paddlenlp/transformers/moe_layer.py diff --git a/paddlenlp/transformers/moe_gate.py b/paddlenlp/transformers/moe_gate.py new file mode 100644 index 000000000000..c6fbbe6fa8d3 --- /dev/null +++ b/paddlenlp/transformers/moe_gate.py @@ -0,0 +1,554 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import paddle +import paddle.distributed as dist +import paddle.nn as nn +import paddle.nn.functional as F + +from ..utils.log import logger + + +@paddle.no_grad() +def compute_optimal_transport(M, r, c, lam=1.0, epsilon=1e-8, max_iters: int = 10): + """ + Computes the optimal transport matrix and Slinkhorn distance using the + Sinkhorn-Knopp algorithm + + Inputs: + - M : cost matrix (n x m) + - r : vector of marginals (n, ) + - c : vector of marginals (m, ) + - lam : strength of the entropic regularization + - epsilon : convergence parameter + + Outputs: + - P : optimal transport matrix (n x m) + - dist : Sinkhorn distance + """ + n, _ = M.shape + # P = (- lam * M).exp() + # P /= P.sum() + P = F.softmax(-M / lam) + u = paddle.zeros(n, "float32") + # normalize this matrix + for _ in range(max_iters): + if (u - P.sum(1)).abs().max() < epsilon: + break + u = P.sum(1) + P *= (r / (u + 1e-8)).reshape((-1, 1)) + P *= (c / (P.sum(0) + 1e-8)).reshape((1, -1)) + P = paddle.where(~P.isnan(), P, paddle.zeros_like(P)) + return P, _ + + +class BaseGate(nn.Layer): + def __init__( + self, + num_experts, + expert_hidden_size, + weight_attr=None, + bias_attr=None, + **kwargs, + ): + super(BaseGate, self).__init__() + + self.num_experts = num_experts + self.expert_hidden_size = expert_hidden_size + + # force keep in float32 when using amp + self._cast_to_low_precision = False + self._weight_attr = weight_attr + self._bias_attr = bias_attr + + self.weight = paddle.create_parameter( + shape=[self.expert_hidden_size, self.num_experts], + attr=self._weight_attr, + dtype="float32", + is_bias=False, + ) + self.bias = paddle.create_parameter( + shape=[self.num_experts], + attr=self._bias_attr, + dtype="float32", + is_bias=True, + ) + + self.group = getattr(kwargs, "group", None) + self.global_aux_loss = getattr(kwargs, "global_aux_loss", False) + if self.global_aux_loss: + assert self.group is not None, "group is required when global_aux_loss is True" + self.rank = dist.get_rank(self.group) + + self.expert_drop = getattr(kwargs, "expert_drop", False) + + def gate_score_func(self, logits): + # [..., hidden_dim] -> [..., num_experts] + with 
paddle.amp.auto_cast(False): + scoring_func = getattr(self, "scoring_func", None) + if scoring_func == "softmax": + scores = F.softmax(logits.astype("float32"), axis=-1) + elif scoring_func == "sigmoid": + scores = F.sigmoid(logits) + elif scoring_func == "tanh": + scores = F.tanh(logits) + elif scoring_func == "relu": + scores = F.relu(logits) + elif scoring_func == "gelu": + scores = F.gelu(logits) + elif scoring_func == "leaky_relu": + scores = F.leaky_relu(logits) + else: + logger.warning(f"insupportable scoring function for MoE gating: {scoring_func}, use softmax instead") + scores = F.softmax(logits.astype("float32"), axis=-1) + return scores + + def scaling_weight(self, weight: paddle.Tensor): + topk = getattr(self, "topk", 1) + scaling_attr = getattr(self, "scaling_attr", None) + + if topk > 1 and isinstance(scaling_attr, bool) and scaling_attr: + # if scaling is a bool, it means that scaling with the weight + scaling_factor = 1 / (weight.sum(axis=-1, keepdim=True) + 1e-20) + elif isinstance(scaling_attr, (int, float)): + scaling_factor = float(scaling_attr) + else: + logger.warning_once(f"scaling_attr is not set, use the default value 1.0") + scaling_factor = 1.0 + + return weight * scaling_factor + + def gumbel_rsample(self, logits: paddle.Tensor) -> paddle.Tensor: + gumbel = paddle.distribution.gumbel.Gumbel(0, 1) + return gumbel.rsample(logits.shape) + + def uniform_sample(self, logits: paddle.Tensor) -> paddle.Tensor: + uniform = paddle.distribution.uniform.Uniform(0, 1) + return uniform.sample(logits.shape) + + def _cal_aux_loss(self, gates, mask): + """ + 计算辅助损失 + + Args: + gates (paddle.Tensor): 表示每个expert的输出概率。形状为[batch_size,num_experts] + mask (paddle.Tensor): 表示每个样本是否属于某个expert。形状为[batch_size,num_experts] + + Returns: + paddle.Tensor: 辅助损失值。 + + """ + me = paddle.mean(gates, axis=0) + ce = paddle.mean(mask.cast("float32"), axis=0) + if self.global_aux_loss: + me_list, ce_list = [], [] + dist.all_gather(me_list, me, group=self.group) + dist.all_gather(ce_list, ce, group=self.group) + + me_list[self.rank] = me + ce_list[self.rank] = ce + me = paddle.stack(me_list).mean(0) + ce = paddle.stack(ce_list).mean(0) + aux_loss = paddle.sum(me * ce) * float(self.num_experts) + return aux_loss + + def _cal_z_loss(self, logits) -> paddle.Tensor: + """ + 计算z损失 + Args: + logits (paddle.paddle.Tensor): 模型输出。形状为[batch_size, num_experts] + Returns: + paddle.paddle.Tensor: z损失值。 + """ + l_zloss = logits.exp().sum(1).log().square().mean() + return l_zloss + + def _cal_orthogonal_loss(self) -> paddle.Tensor: + """Gate weight orthogonal loss. + + Returns: + Paddle.Tensor: orthogonal loss + """ + weight = F.normalize(self.weight, axis=0) + orthogonal_loss = paddle.mean(paddle.square(paddle.matmul(weight.T, weight) - paddle.eye(self.num_experts))) + return orthogonal_loss + + @paddle.no_grad() + def _capacity(self, gates: paddle.Tensor, capacity_factor: float, min_capacity: int) -> paddle.Tensor: + """Calculate the capacity for each expert based on the gates and capacity factor. + + Args: + gates (paddle.Tensor): A tensor of shape [num_tokens, num_experts] representing the probability distribution + over experts for each token. + capacity_factor (float): A scalar float value representing the capacity factor for each expert. + min_capacity (int): A scalar integer value representing the minimum capacity for each expert. + + Returns: + int: A tensor value representing the calculated capacity for each expert. 
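        Example:
            An illustrative walk-through of the formula implemented below (the
            numbers are made up for the example): with 8 tokens routed over 4
            experts, capacity_factor=1.5 and min_capacity=4,
            int((8 // 4) * 1.5) = 3, which is below min_capacity, so the
            returned capacity is 4.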
+ """ + assert gates.ndim == 2, f"gates should be 2D, but got {gates.ndim}, {gates.shape}" + # gates has shape of SE + num_tokens = gates.shape[0] + num_experts = gates.shape[1] + capacity = int((num_tokens // num_experts) * capacity_factor) + if capacity < min_capacity: + capacity = min_capacity + assert capacity > 0, f"requires capacity > 0, capacity_factor: {capacity_factor}, input_shape: {gates.shape}" + + return capacity + + @paddle.no_grad() + def _one_hot_to_float(self, x, num_classes): + if x.dtype not in (paddle.int32, paddle.int64): + x = paddle.cast(x, paddle.int64) + return F.one_hot(x, num_classes=num_classes).cast(paddle.float32) + + @paddle.no_grad() + def _one_hot_to_int64(self, x, num_classes): + if x.dtype not in (paddle.int32, paddle.int64): + x = paddle.cast(x, paddle.int64) + return F.one_hot(x, num_classes=num_classes).cast(paddle.int64) + + def top1gating( + self, + logits: paddle.Tensor, + capacity_factor: float, + min_capacity: int, + used_token: paddle.Tensor = None, + noisy_gate_policy: Optional[str] = None, + drop_tokens: bool = True, + use_rts: bool = True, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Implements Top1Gating on logits.""" + if noisy_gate_policy == "RSample": + logits += self.gumbel_rsample(logits.shape) + + gates = self.gate_score_func(logits=logits) + capacity = self._capacity(gates, capacity_factor, min_capacity) + + # Create a mask for 1st's expert per token + # noisy gating + indices1_s = paddle.argmax(logits if noisy_gate_policy == "RSample" else gates, axis=1) # 仅保存最大值位置 + mask1 = self._one_hot_to_float(indices1_s, num_classes=self.num_experts) # 将最大值位置转换为one-hot向量 [s, e] + + # mask only used tokens + if used_token is not None: + mask1 = paddle.einsum("s,se->se", used_token, mask1) # 将used_token与mask1进行逐元素相乘,得到新的mask1 + + # gating decisions + exp_counts = paddle.sum(mask1, axis=0) # 计算每个专家的token数量 + + # if we don't want to drop any tokens + if not drop_tokens: + new_capacity = paddle.max(exp_counts) # 计算每个专家的token数量 + # Communicate across expert processes to pick the maximum capacity. + if self.group is not None: + dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=self.group) # 在专家进程之间进行最大值计算 + # Make sure the capacity value does not exceed the number of tokens. + capacity = int(min(new_capacity, paddle.tensor(mask1.size(0)))) + + l_aux = self._cal_aux_loss(gates, mask1) + l_zloss = self._cal_z_loss(logits) + + # Random Token Selection + if use_rts: + mask1_rand = mask1 * self.uniform_sample(mask1) + else: + mask1_rand = mask1 + + assert ( + logits.shape[0] >= min_capacity + ), "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size." 
+ + _, top_idx = paddle.topk(mask1_rand, k=capacity, axis=0) # 选择top_capacity个token + + # 将mask1中的元素与top_idx进行逐元素相乘,得到新的mask1 + new_mask1 = mask1 * paddle.zeros_like(mask1).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=0) + mask1 = new_mask1 + + # Compute locations in capacity buffer + locations1 = paddle.cumsum(mask1, axis=0) - 1 # 计算每个token在mask1中的位置 + + # Store the capacity location for each token + locations1_s = paddle.sum(locations1 * mask1, axis=1).cast(paddle.int64) # 计算每个token在mask1中的位置 + + # Normalize gate probabilities + mask1_float = mask1.cast(paddle.float32) + gates = gates / gates * mask1_float + + locations1_sc = self._one_hot_to_float(locations1_s, capacity) + combine_weights = paddle.einsum("se,sc->sec", gates, locations1_sc) + dispatch_mask = combine_weights.cast(paddle.bool).detach() + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + + def top2gating( + self, + logits: paddle.Tensor, + capacity_factor: float, + min_capacity: int, + drop_tokens: bool = True, + top2_2nd_expert_sampling: bool = True, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """ + Args: + logits: [S, E],形状为 [seq_len, num_experts],用于计算top2 gate。 + cap: 表示每个token可以分发的最大数量的超参数。 + + Returns: + tuple: + - capacity: 每个token可分发的最大数量。 + - dispatch_masks: 用于dispatching的mask。第一个元素是第一类token的mask;第二个元素是第二类token的mask。 + - combine_weights:用于combining的权重。第一个元素是第一类token的权重;第二个元素是第二类token的权重。 + - scatter_indexes: 用于scattering的索引。第一个元素是第一类token的索引;第二个元素是第二类token的索引。 + - loss_aux: aux loss。 + - loss_z: z loss。 + """ + """Implements Top2Gating on logits.""" + # everything is in fp32 in this function + gates = self.gate_score_func(logits=logits) + + # Create a mask for 1st's expert per token. + indices1_s = paddle.argmax(gates, axis=1) # [S, 1] + mask1 = self._one_hot_to_int64(indices1_s, self.num_experts) # [S, E] + + if top2_2nd_expert_sampling: + # Create a mask for 2nd's expert per token using Gumbel-max trick. + # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ + logits += self.gumbel_rsample(logits) + + # Replace top-expert with min value + logits_except1 = logits.masked_fill(mask1.cast(paddle.bool), float("-inf")) # [S, E] + indices2_s = paddle.argmax(logits_except1, axis=1) # [S, 1] + mask2 = self._one_hot_to_int64(indices2_s, self.num_experts) # [S, E] + + # Note: mask1 and mask2 can be combined to form a single mask. + # mask = paddle.concat([mask1, mask2], axis=0) + # locations = paddle.cumsum(mask, axis=0) - 1 + # locations1, locations2 = locations.split(2, axis=0) + # Compute locations in capacity buffer. + locations1 = paddle.cumsum(mask1, axis=0) - 1 # [S, E] + locations2 = paddle.cumsum(mask2, axis=0) - 1 # [S, E] + # Update 2nd's location by accounting for locations of 1st. + locations2 += paddle.sum(mask1, axis=0, keepdim=True) + + l_aux = self._cal_aux_loss(gates, mask1) + l_zloss = self._cal_z_loss(logits) + + # gating decisions + exp_counts = paddle.sum(mask1 + mask2, axis=0) + if drop_tokens: + # Calculate configured capacity and remove locations outside capacity from mask + capacity = self._capacity(gates, capacity_factor, min_capacity) + # Remove locations outside capacity from mask. 
+ mask1 *= (locations1 < capacity).cast(paddle.int64) + mask2 *= (locations2 < capacity).cast(paddle.int64) + else: + # Do not drop tokens - set capacity according to current expert assignments + new_capacity = paddle.max(exp_counts) + if self.group is not None: + dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=self.group) + capacity = int(new_capacity) + + # Store the capacity location for each token. + locations1_s = paddle.sum(locations1 * mask1, axis=1) + locations2_s = paddle.sum(locations2 * mask2, axis=1) + + # Normalize gate probabilities + mask1_float = mask1.cast(paddle.float32) + mask2_float = mask2.cast(paddle.float32) + gates1_s = paddle.einsum("se,se->s", gates, mask1_float) + gates2_s = paddle.einsum("se,se->s", gates, mask2_float) + denom_s = gates1_s + gates2_s + # Avoid divide-by-zero + denom_s = paddle.clip(denom_s, min=paddle.finfo(denom_s.dtype).eps) + gates1_s /= denom_s + gates2_s /= denom_s + + # Calculate combine_weights and dispatch_mask + gates1 = paddle.einsum("s,se->se", gates1_s, mask1_float) + gates2 = paddle.einsum("s,se->se", gates2_s, mask2_float) + locations1_sc = self._one_hot_to_float(locations1_s, capacity) + locations2_sc = self._one_hot_to_float(locations2_s, capacity) + combine1_sec = paddle.einsum("se,sc->sec", gates1, locations1_sc) + combine2_sec = paddle.einsum("se,sc->sec", gates2, locations2_sc) + combine_weights = combine1_sec + combine2_sec + dispatch_mask = combine_weights.cast(paddle.bool) + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + + def topkgating( + self, + logits: paddle.Tensor, + k: int, + capacity_factor: float, + min_capacity: int, + drop_tokens: bool = True, + drop_policy: str = "probs", + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Implements TopKGating on logits.""" + + # everything is in fp32 in this function + # get topk gates + top_gate, top_idx = paddle.topk(logits, k=k, axis=1) + # gating decisions + gates = self.gate_score_func(logits=logits) + # get topk mask + topk_masked_gates = paddle.zeros_like(logits).put_along_axis(top_idx, top_gate, axis=1) + mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1) + exp_counts = paddle.sum(mask, axis=0) + + l_aux = self._cal_aux_loss(gates, mask) + l_zloss = self._cal_z_loss(logits) + + if drop_tokens: + # Calculate configured capacity and remove locations outside capacity from mask + capacity = self._capacity(gates, capacity_factor * k, min_capacity) + + # update mask and locations by capacity + if drop_policy == "probs": + capacity_probs, capacity_indices = paddle.topk(topk_masked_gates, k=capacity, axis=0, sorted=False) + capacity_mask = paddle.zeros_like(logits).put_along_axis( + capacity_indices, paddle.to_tensor(1.0), axis=0 + ) + mask = mask * capacity_mask + locations = paddle.cumsum(mask, axis=0) - 1 + + elif drop_policy == "position": + locations = paddle.cumsum(mask, axis=0) - 1 + mask *= (locations < capacity).cast(paddle.int64) + else: + raise ValueError(f"Invalid drop_policy: {drop_policy}") + + else: + # Do not drop tokens - set capacity according to current expert assignments + new_capacity = paddle.max(exp_counts) + if self.group is not None: + dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=self.group) + capacity = int(new_capacity) + + # normalize gates + gates_masked = gates * mask + gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True) + denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps) + 
gates_masked = gates_masked / denom_s + + # dispatch_mask + locations_sc = self._one_hot_to_float(locations * mask, num_classes=capacity) + combine_weights = paddle.einsum("se,sec->sec", gates_masked, locations_sc) + dispatch_mask = combine_weights.cast(paddle.bool) + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + + def forward(self, hidden_states): + raise NotImplementedError("Please implement the forward function.") + + +class TopKGate(BaseGate): + def __init__( + self, + num_experts, + expert_hidden_size, + weight_attr=None, + bias_attr=None, + topk=2, + scoring_func="softmax", + scaling_attr=None, + ): + super().__init__(num_experts, expert_hidden_size, weight_attr, bias_attr) + self.topk = topk + self.scoring_func = scoring_func + self.scaling_attr = scaling_attr + + def forward( + self, + hidden_states: paddle.Tensor, + used_token: paddle.Tensor = None, + ): + bsz, seq_len, hidden_size = hidden_states.shape + hidden_states = hidden_states.reshape([-1, hidden_size]) + logits = F.linear(x=paddle.cast(hidden_states, paddle.float32), weight=self.weight, bias=self.bias) + if self.topk == 1: + gate_output = self.top1gating( + logits, + self.capacity_factor if self.training else self.eval_capacity_factor, + self.min_capacity, + used_token, + self.noisy_gate_policy if self.training else None, + self.drop_tokens, + ) + elif self.topk == 2: + gate_output = self.top2gating( + logits, + self.capacity_factor if self.training else self.eval_capacity_factor, + self.min_capacity, + self.drop_tokens, + self.top2_2nd_expert_sampling, + ) + else: + gate_output = self.topkgating( + logits, + self.topk, + self.capacity_factor if self.training else self.eval_capacity_factor, + self.min_capacity, + self.drop_tokens, + ) + + return gate_output + + +class GroupTopKGate(BaseGate): + def __init__( + self, + num_experts, + expert_hidden_size, + weight_attr=None, + bias_attr=None, + topk=2, + scoring_func="softmax", + scaling_attr=None, + n_group=1, + topk_group=1, + ): + super().__init__(num_experts, expert_hidden_size, weight_attr, bias_attr) + self.topk = topk + self.scoring_func = scoring_func + self.scaling_attr = scaling_attr + self.n_group = n_group + self.topk_group = topk_group + + def forward(self, hidden_states): + bsz, seq_len, hidden_size = hidden_states.shape + hidden_states = hidden_states.reshape([-1, hidden_size]) + scores = self.gate_score_func(hidden_states) + + group_scores = scores.reshape([bsz * seq_len, self.n_group, -1]).max(axis=-1).values # [n, n_group] + group_idx = paddle.topk(group_scores, k=self.topk_group, axis=-1, sorted=False)[1] # [n, top_k_group] + group_mask = paddle.zeros_like(group_scores).scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(bsz * seq_len, self.n_group, self.num_experts // self.n_group) + .reshape(bsz * seq_len, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + + topk_weight, topk_idx = paddle.topk(tmp_scores, k=self.topk, axis=-1, largest=True, sorted=False) + + if self.scaling_attr is not None: + topk_weight = self.scaling_weight(topk_weight) + + return topk_weight, topk_idx, scores diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py new file mode 100644 index 000000000000..7edd398e7b44 --- /dev/null +++ b/paddlenlp/transformers/moe_layer.py @@ -0,0 +1,276 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple +from contextlib import contextmanager +from typing import Any, List, Tuple + +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import Tensor, nn +from paddle.distributed import fleet +from paddle.distributed.communication import stream +from paddle.distributed.communication.group import Group +from paddle.distributed.fleet.utils import recompute + +from ..utils.log import logger + +GateOutput = namedtuple( + "GateOutput", + [ + "aux", + "z", + "logits", + ], +) + + +def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): + """ + Rearranges the input tensor `x` based on gate results, truncates it according to the specified capacity, and performs padding. + + Args: + x (Tensor)[Seq, Dim]: The input tensor. + dispatch_mask (List[Tensor[Seq, 1], Tensor[Seq, 1]]): A list of dispatch masks. + scatter_index (Union[List[Tensor[Seq,], Tensor[Seq]], Tensor[Seq, 2]]): A list or tensor representing scatter indices. + num_experts (int): The number of experts. + capacity (int): The capacity size. + + Returns: + Tensor [Expert*Capacity, Dim]: The output tensor after dispatching. + """ + output = None + orig_dtype = x.dtype + if isinstance(scatter_index, paddle.Tensor): + scatter_index = scatter_index.unbind(1) + for i_scatter_index, i_dispatch_mask in zip(scatter_index, dispatch_mask): + init_output = paddle.zeros([num_experts * capacity, x.shape[-1]], dtype="float32") + updates = x * i_dispatch_mask.cast(x.dtype) + if output is None: + output = paddle.scatter( + init_output, + i_scatter_index, + updates, + overwrite=False, + ) + else: + output = output + paddle.scatter( + init_output, + i_scatter_index, + updates, + overwrite=False, + ) + if output.dtype != orig_dtype: + output = output.cast(orig_dtype) + return output + + +def combining(x, combine_weights, scatter_index): + """ + Performs combination and aggregation operations on the input matrix. + + Args: + x: Tensor[num_experts * capacity, dim] - The input matrix to be processed, where the last dimension represents the number of features. + combine_weights: Union[List[Tensor[seq, 1], Tensor[seq, 1]], Tensor[seq, 2, 1]] - A list or tensor containing combination weights for each feature. + scatter_index: Union[List[Tensor[seq], Tensor[seq]], Tensor[seq, 2]] - A tuple of indices indicating which elements are to be aggregated, where the first element is the row index and the second element is the column index. + + Returns: + Tensor: The output matrix after combination and aggregation, with a shape of [n, dim * num_features], where n is the number of samples in the input matrix. 
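        Example:
            A shape sketch with illustrative sizes (seq=4, num_experts=2, capacity=4,
            dim=16, top-2 routing), following the list-input path of the code below:
                x:               [num_experts * capacity, dim] -> [8, 16]
                combine_weights: list of 2 tensors, each [4, 1]
                scatter_index:   list of 2 tensors, each [4]
                returned value:  [seq, dim] -> [4, 16]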
+ """ + + dim = x.shape[-1] + if isinstance(scatter_index, (list, tuple)): + scatter_index = paddle.concat([i.unsqueeze([-1]) for i in scatter_index], -1) + scatter_index = scatter_index.reshape([-1]) + num_k = len(combine_weights) if isinstance(combine_weights, (list, tuple)) else combine_weights.shape[-1] + x = paddle.gather(x, scatter_index).reshape([-1, num_k, dim]) # [seq,2,dim] + if isinstance(combine_weights, (list, tuple)): + combine_weights = paddle.concat(combine_weights, -1).unsqueeze([1]) + return paddle.matmul(combine_weights, x).squeeze(1) # [seq,1,2] @ [seq,2,dim] -> [seq,1,dim] + + +class _AllToAll(paddle.autograd.PyLayer): + @staticmethod + def forward( + ctx: Any, + input: Tensor, + group: Group, + ) -> Tensor: # type: ignore + """ + All-to-all communication in the group. + + Args: + ctx (Any): Context object. + input (Tensor): Input tensor. + group (Group): The group object. + + Returns: + Tensor: Output tensor. + """ + + ctx.group = group + # return input + if dist.get_world_size(group) <= 1: + return input + output = paddle.empty_like(input) + stream.alltoall_single(output, input, None, None, group, True, True) + return output + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor]: + """ + Aggregates gradient information from all input tensors into a single tensor. + + Args: + ctx (Any): The context object used to store information that needs to be passed. + *grad_output (Tensor): A list of input tensors whose gradients are to be aggregated. + + Returns: + Tuple[Tensor]: A tuple containing a tensor that holds the gradients of all input tensors. + + """ + # return grad_output + return _AllToAll.apply(*grad_output, ctx.group) + + +class MoELayer(nn.Layer): + """MOELayer module which implements MixtureOfExperts as described in Gshard_. + :: + + gate = Top2Gate(model_dim, num_experts) + + moe = MoELayer(gate, expert) + output = moe(input) + l_aux = moe.l_aux + + .. 
Gshard_: https://arxiv.org/pdf/2006.16668.pdf + + Args: + gate (paddle.nn.Layer): + gate network + expert (paddle.nn.LayerList): + expert network, LayerList 长度是 per_device 上的 expert 数。 + group (paddle.ProgressGroup) + recompute: 启用MOE内recomupte + Returns: + output + combine_weight + router-loss + """ + + def __init__( + self, + gate: nn.Layer, + experts: List[nn.Layer], + layer_idx, + group: Group = None, + recompute=False, + all_to_all_dropout=0.0, + moe_num_experts=2, + ): + super().__init__() + self.gate = gate + self.layer_idx = layer_idx + self.recompute = recompute + logger.info(f"using moe recompute={recompute}") + for p in self.gate.parameters(): + p.is_gate = True + if type(experts) == nn.LayerList: + self.experts = experts + else: + logger.info(f"using fused experts, type={type(experts)}") + self.experts = nn.LayerList([experts]) + self.group = group + self.all_to_all_dropout = all_to_all_dropout + is_dummy_moe = dist.get_world_size(group) == 1 or dist.get_world_size(group) == -1 + + for k in experts: + if k is not None: + for p in k.parameters(): + p.expert = not is_dummy_moe + p.no_sync = not is_dummy_moe + # logger.info(f"expert param={p.name}, no-sync={p.no_sync}") + + self.world_size = dist.get_world_size(group) + self.rank = dist.get_rank(self.group) + if self.world_size < 1: + self.world_size = 1 + if self.rank < 0: + self.rank = 0 + + self.num_local_experts = moe_num_experts // self.world_size + + def forward(self, input): + true_experts = self.experts[self.rank * self.num_local_experts : (self.rank + 1) * self.num_local_experts] + if input.ndim == 3: + orig_shape = input.shape + reshaped_input = input.reshape([-1, input.shape[-1]]) + else: + orig_shape = None + assert len(input.shape) == 2, f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}" + + # Implement Algorithm 2 from GShard paper. + seqlen, d_model = input.shape + + # Reshape into S tokens by dropping sequence dimension. 
+ # reshaped_input = input.reshape(-1, d_model) + # assert reshaped_input.shape[0] % len(self.experts) == 0, \ + # f'num tokens must be order of number of local experts, {input[0].shape[0]} vs {len(self.experts)}' + def fwdfn(dispatched_input): + expert_outputs = [] + chunks = dispatched_input.unbind(1) + assert len(chunks) == len(true_experts), (len(chunks), len(true_experts)) + for chunk, expert in zip(chunks, true_experts): + chunk = chunk.contiguous() + expert_outputs += [expert(chunk)] + expert_output = paddle.stack(expert_outputs, axis=1) # [ecm] + return expert_output + + assert self.gate is not None + if hasattr(self, "rng") and self.rng.random() < self.all_to_all_dropout: + orig_shape_2 = input.shape + input = input.reshape([self.world_size, self.num_local_experts, -1, input.shape[-1]]) + output = fwdfn(input) + output += self.gate.weight.sum() * 0.0 # hack for grad + output = output.reshape(orig_shape or orig_shape_2) # [e*1,c,m] + return output, None, 0 + + capacity, dispatch_mask, combine_weights, scatter_index, router_loss = self.gate(input) + self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1]) + + if self.world_size > 1: + dispatched_input = _AllToAll.apply(dispatched_input, self.group) + dispatched_input = dispatched_input.reshape([self.world_size * self.num_local_experts, capacity, d_model]) + expert_output = ( + recompute(fwdfn, dispatched_input) if self.recompute and self.training else fwdfn(dispatched_input) + ) + d_model_out = expert_output.shape[-1] + + if self.world_size > 1: + expert_output = _AllToAll.apply(expert_output, self.group) # 拿到不同device上的expert计算结果 + + expert_output = expert_output.reshape( + [self.world_size * self.num_local_experts * capacity, d_model_out] + ) # [e * 1, c, m] + combined_output = combining(expert_output, combine_weights, scatter_index) + + if orig_shape: + combined_output = combined_output.reshape( + orig_shape[:-1] + + [ + d_model_out, + ] + ) + return combined_output, combine_weights, router_loss From a6260cff497cb6820c4a5e185b180cfb0d14b1ee Mon Sep 17 00:00:00 2001 From: drownfish19 Date: Wed, 9 Oct 2024 02:12:41 +0000 Subject: [PATCH 02/20] update gates --- paddlenlp/transformers/moe_gate.py | 327 ++++++++++------------------- 1 file changed, 114 insertions(+), 213 deletions(-) diff --git a/paddlenlp/transformers/moe_gate.py b/paddlenlp/transformers/moe_gate.py index c6fbbe6fa8d3..3642a64af6f3 100644 --- a/paddlenlp/transformers/moe_gate.py +++ b/paddlenlp/transformers/moe_gate.py @@ -13,7 +13,7 @@ # limitations under the License. 
from __future__ import annotations -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Tuple import paddle import paddle.distributed as dist @@ -23,115 +23,28 @@ from ..utils.log import logger -@paddle.no_grad() -def compute_optimal_transport(M, r, c, lam=1.0, epsilon=1e-8, max_iters: int = 10): - """ - Computes the optimal transport matrix and Slinkhorn distance using the - Sinkhorn-Knopp algorithm - - Inputs: - - M : cost matrix (n x m) - - r : vector of marginals (n, ) - - c : vector of marginals (m, ) - - lam : strength of the entropic regularization - - epsilon : convergence parameter - - Outputs: - - P : optimal transport matrix (n x m) - - dist : Sinkhorn distance - """ - n, _ = M.shape - # P = (- lam * M).exp() - # P /= P.sum() - P = F.softmax(-M / lam) - u = paddle.zeros(n, "float32") - # normalize this matrix - for _ in range(max_iters): - if (u - P.sum(1)).abs().max() < epsilon: - break - u = P.sum(1) - P *= (r / (u + 1e-8)).reshape((-1, 1)) - P *= (c / (P.sum(0) + 1e-8)).reshape((1, -1)) - P = paddle.where(~P.isnan(), P, paddle.zeros_like(P)) - return P, _ - - -class BaseGate(nn.Layer): - def __init__( - self, - num_experts, - expert_hidden_size, - weight_attr=None, - bias_attr=None, - **kwargs, - ): - super(BaseGate, self).__init__() - - self.num_experts = num_experts - self.expert_hidden_size = expert_hidden_size - - # force keep in float32 when using amp - self._cast_to_low_precision = False - self._weight_attr = weight_attr - self._bias_attr = bias_attr - - self.weight = paddle.create_parameter( - shape=[self.expert_hidden_size, self.num_experts], - attr=self._weight_attr, - dtype="float32", - is_bias=False, - ) - self.bias = paddle.create_parameter( - shape=[self.num_experts], - attr=self._bias_attr, - dtype="float32", - is_bias=True, - ) - - self.group = getattr(kwargs, "group", None) - self.global_aux_loss = getattr(kwargs, "global_aux_loss", False) - if self.global_aux_loss: - assert self.group is not None, "group is required when global_aux_loss is True" - self.rank = dist.get_rank(self.group) - - self.expert_drop = getattr(kwargs, "expert_drop", False) - - def gate_score_func(self, logits): +class MoEGateMixin: + def gate_score_func(self, logits: paddle.Tensor) -> paddle.Tensor: # [..., hidden_dim] -> [..., num_experts] with paddle.amp.auto_cast(False): scoring_func = getattr(self, "scoring_func", None) if scoring_func == "softmax": - scores = F.softmax(logits.astype("float32"), axis=-1) + scores = F.softmax(logits.cast("float32"), axis=-1) elif scoring_func == "sigmoid": - scores = F.sigmoid(logits) + scores = F.sigmoid(logits.cast("float32")) elif scoring_func == "tanh": - scores = F.tanh(logits) + scores = F.tanh(logits.cast("float32")) elif scoring_func == "relu": - scores = F.relu(logits) + scores = F.relu(logits.cast("float32")) elif scoring_func == "gelu": - scores = F.gelu(logits) + scores = F.gelu(logits.cast("float32")) elif scoring_func == "leaky_relu": - scores = F.leaky_relu(logits) + scores = F.leaky_relu(logits.cast("float32")) else: logger.warning(f"insupportable scoring function for MoE gating: {scoring_func}, use softmax instead") - scores = F.softmax(logits.astype("float32"), axis=-1) + scores = F.softmax(logits.cast("float32"), axis=-1) return scores - def scaling_weight(self, weight: paddle.Tensor): - topk = getattr(self, "topk", 1) - scaling_attr = getattr(self, "scaling_attr", None) - - if topk > 1 and isinstance(scaling_attr, bool) and scaling_attr: - # if scaling is a bool, it means that scaling with the weight 
- scaling_factor = 1 / (weight.sum(axis=-1, keepdim=True) + 1e-20) - elif isinstance(scaling_attr, (int, float)): - scaling_factor = float(scaling_attr) - else: - logger.warning_once(f"scaling_attr is not set, use the default value 1.0") - scaling_factor = 1.0 - - return weight * scaling_factor - def gumbel_rsample(self, logits: paddle.Tensor) -> paddle.Tensor: gumbel = paddle.distribution.gumbel.Gumbel(0, 1) return gumbel.rsample(logits.shape) @@ -140,6 +53,42 @@ def uniform_sample(self, logits: paddle.Tensor) -> paddle.Tensor: uniform = paddle.distribution.uniform.Uniform(0, 1) return uniform.sample(logits.shape) + @paddle.no_grad() + def _one_hot_to_float(self, x, num_classes): + if x.dtype not in (paddle.int32, paddle.int64): + x = paddle.cast(x, paddle.int64) + return F.one_hot(x, num_classes=num_classes).cast(paddle.float32) + + @paddle.no_grad() + def _one_hot_to_int64(self, x, num_classes): + if x.dtype not in (paddle.int32, paddle.int64): + x = paddle.cast(x, paddle.int64) + return F.one_hot(x, num_classes=num_classes).cast(paddle.int64) + + @paddle.no_grad() + def _capacity(self, gates: paddle.Tensor, capacity_factor: float, min_capacity: int) -> paddle.Tensor: + """Calculate the capacity for each expert based on the gates and capacity factor. + + Args: + gates (paddle.Tensor): A tensor of shape [num_tokens, num_experts] representing the probability distribution + over experts for each token. + capacity_factor (float): A scalar float value representing the capacity factor for each expert. + min_capacity (int): A scalar integer value representing the minimum capacity for each expert. + + Returns: + int: A tensor value representing the calculated capacity for each expert. + """ + assert gates.ndim == 2, f"gates should be 2D, but got {gates.ndim}, {gates.shape}" + # gates has shape of SE + num_tokens = gates.shape[0] + num_experts = gates.shape[1] + capacity = int((num_tokens // num_experts) * capacity_factor) + if capacity < min_capacity: + capacity = min_capacity + assert capacity > 0, f"requires capacity > 0, capacity_factor: {capacity_factor}, input_shape: {gates.shape}" + + return capacity + def _cal_aux_loss(self, gates, mask): """ 计算辅助损失 @@ -187,62 +136,68 @@ def _cal_orthogonal_loss(self) -> paddle.Tensor: orthogonal_loss = paddle.mean(paddle.square(paddle.matmul(weight.T, weight) - paddle.eye(self.num_experts))) return orthogonal_loss - @paddle.no_grad() - def _capacity(self, gates: paddle.Tensor, capacity_factor: float, min_capacity: int) -> paddle.Tensor: - """Calculate the capacity for each expert based on the gates and capacity factor. - Args: - gates (paddle.Tensor): A tensor of shape [num_tokens, num_experts] representing the probability distribution - over experts for each token. - capacity_factor (float): A scalar float value representing the capacity factor for each expert. - min_capacity (int): A scalar integer value representing the minimum capacity for each expert. +class PretrainedMoEGate(nn.Layer, MoEGateMixin): + def __init__(self, num_experts, expert_hidden_size, **kwargs): + super(PretrainedMoEGate, self).__init__() - Returns: - int: A tensor value representing the calculated capacity for each expert. 
- """ - assert gates.ndim == 2, f"gates should be 2D, but got {gates.ndim}, {gates.shape}" - # gates has shape of SE - num_tokens = gates.shape[0] - num_experts = gates.shape[1] - capacity = int((num_tokens // num_experts) * capacity_factor) - if capacity < min_capacity: - capacity = min_capacity - assert capacity > 0, f"requires capacity > 0, capacity_factor: {capacity_factor}, input_shape: {gates.shape}" + self.num_experts = num_experts + self.expert_hidden_size = expert_hidden_size - return capacity + self.capacity_factor = kwargs["capacity_factor"] if hasattr(kwargs, "capacity_factor") else 1.0 # fmt:skip + self.eval_capacity_factor = kwargs["eval_capacity_factor"] if hasattr(kwargs, "eval_capacity_factor") else 1.0 # fmt:skip + self.min_capacity = kwargs["min_capacity"] if hasattr(kwargs, "min_capacity") else 1.0 # fmt:skip - @paddle.no_grad() - def _one_hot_to_float(self, x, num_classes): - if x.dtype not in (paddle.int32, paddle.int64): - x = paddle.cast(x, paddle.int64) - return F.one_hot(x, num_classes=num_classes).cast(paddle.float32) + # force keep in float32 when using amp + self._cast_to_low_precision = False + self._weight_attr = kwargs["weight_attr"] if hasattr(kwargs, "weight_attr") else None # fmt:skip + self._bias_attr = kwargs["bias_attr"] if hasattr(kwargs, "bias_attr") else None # fmt:skip - @paddle.no_grad() - def _one_hot_to_int64(self, x, num_classes): - if x.dtype not in (paddle.int32, paddle.int64): - x = paddle.cast(x, paddle.int64) - return F.one_hot(x, num_classes=num_classes).cast(paddle.int64) + self.weight = paddle.create_parameter( + shape=[self.expert_hidden_size, self.num_experts], + attr=self._weight_attr, + dtype="float32", + is_bias=False, + ) + if self._bias_attr is not None and self._bias_attr: + self.bias = paddle.create_parameter( + shape=[self.num_experts], + attr=self._bias_attr, + dtype="float32", + is_bias=True, + ) + + self.group = kwargs["group"] if hasattr(kwargs, "group") else None + self.global_aux_loss = kwargs["global_aux_loss"] if hasattr(kwargs, "global_aux_loss") else False + if self.global_aux_loss: + assert self.group is not None, "group is required when global_aux_loss is True" + self.rank = dist.get_rank(self.group) + + self.expert_drop = kwargs["expert_drop"] if hasattr(kwargs, "expert_drop") else False + self.noisy_gate_policy = kwargs["noisy_gate_policy"] if hasattr(kwargs, "noisy_gate_policy") else None + self.drop_tokens = kwargs["drop_tokens"] if hasattr(kwargs, "drop_tokens") else True + self.use_rts = kwargs["use_rts"] if hasattr(kwargs, "use_rts") else True + self.top2_2nd_expert_sampling = ( + kwargs["top2_2nd_expert_sampling"] if hasattr(kwargs, "top2_2nd_expert_sampling") else True + ) + self.drop_policy = kwargs["drop_policy"] if hasattr(kwargs, "drop_policy") else "probs" + self.top_k = kwargs["top_k"] if hasattr(kwargs, "top_k") else 1 def top1gating( self, logits: paddle.Tensor, - capacity_factor: float, - min_capacity: int, used_token: paddle.Tensor = None, - noisy_gate_policy: Optional[str] = None, - drop_tokens: bool = True, - use_rts: bool = True, ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Implements Top1Gating on logits.""" - if noisy_gate_policy == "RSample": + if self.noisy_gate_policy == "RSample": logits += self.gumbel_rsample(logits.shape) gates = self.gate_score_func(logits=logits) - capacity = self._capacity(gates, capacity_factor, min_capacity) + capacity = self._capacity(gates, self.capacity_factor, self.min_capacity) # Create a mask for 1st's expert 
per token # noisy gating - indices1_s = paddle.argmax(logits if noisy_gate_policy == "RSample" else gates, axis=1) # 仅保存最大值位置 + indices1_s = paddle.argmax(logits if self.noisy_gate_policy == "RSample" else gates, axis=1) # 仅保存最大值位置 mask1 = self._one_hot_to_float(indices1_s, num_classes=self.num_experts) # 将最大值位置转换为one-hot向量 [s, e] # mask only used tokens @@ -253,7 +208,7 @@ def top1gating( exp_counts = paddle.sum(mask1, axis=0) # 计算每个专家的token数量 # if we don't want to drop any tokens - if not drop_tokens: + if not self.drop_tokens: new_capacity = paddle.max(exp_counts) # 计算每个专家的token数量 # Communicate across expert processes to pick the maximum capacity. if self.group is not None: @@ -265,13 +220,13 @@ def top1gating( l_zloss = self._cal_z_loss(logits) # Random Token Selection - if use_rts: + if self.use_rts: mask1_rand = mask1 * self.uniform_sample(mask1) else: mask1_rand = mask1 assert ( - logits.shape[0] >= min_capacity + logits.shape[0] >= self.min_capacity ), "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size." _, top_idx = paddle.topk(mask1_rand, k=capacity, axis=0) # 选择top_capacity个token @@ -299,10 +254,6 @@ def top1gating( def top2gating( self, logits: paddle.Tensor, - capacity_factor: float, - min_capacity: int, - drop_tokens: bool = True, - top2_2nd_expert_sampling: bool = True, ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: """ Args: @@ -312,9 +263,9 @@ def top2gating( Returns: tuple: - capacity: 每个token可分发的最大数量。 - - dispatch_masks: 用于dispatching的mask。第一个元素是第一类token的mask;第二个元素是第二类token的mask。 - - combine_weights:用于combining的权重。第一个元素是第一类token的权重;第二个元素是第二类token的权重。 - - scatter_indexes: 用于scattering的索引。第一个元素是第一类token的索引;第二个元素是第二类token的索引。 + - dispatch_masks: 用于dispatching的mask。 + - combine_weights:用于combining的权重。 + - scatter_indexes: 用于scattering的索引。 - loss_aux: aux loss。 - loss_z: z loss。 """ @@ -326,7 +277,7 @@ def top2gating( indices1_s = paddle.argmax(gates, axis=1) # [S, 1] mask1 = self._one_hot_to_int64(indices1_s, self.num_experts) # [S, E] - if top2_2nd_expert_sampling: + if self.top2_2nd_expert_sampling: # Create a mask for 2nd's expert per token using Gumbel-max trick. # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ logits += self.gumbel_rsample(logits) @@ -351,9 +302,9 @@ def top2gating( # gating decisions exp_counts = paddle.sum(mask1 + mask2, axis=0) - if drop_tokens: + if self.drop_tokens: # Calculate configured capacity and remove locations outside capacity from mask - capacity = self._capacity(gates, capacity_factor, min_capacity) + capacity = self._capacity(gates, self.capacity_factor, self.min_capacity) # Remove locations outside capacity from mask. 
mask1 *= (locations1 < capacity).cast(paddle.int64) mask2 *= (locations2 < capacity).cast(paddle.int64) @@ -394,17 +345,10 @@ def top2gating( def topkgating( self, logits: paddle.Tensor, - k: int, - capacity_factor: float, - min_capacity: int, - drop_tokens: bool = True, - drop_policy: str = "probs", ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Implements TopKGating on logits.""" - - # everything is in fp32 in this function # get topk gates - top_gate, top_idx = paddle.topk(logits, k=k, axis=1) + top_gate, top_idx = paddle.topk(logits, k=self.top_k, axis=1) # gating decisions gates = self.gate_score_func(logits=logits) # get topk mask @@ -415,24 +359,22 @@ def topkgating( l_aux = self._cal_aux_loss(gates, mask) l_zloss = self._cal_z_loss(logits) - if drop_tokens: + if self.drop_tokens: # Calculate configured capacity and remove locations outside capacity from mask - capacity = self._capacity(gates, capacity_factor * k, min_capacity) + capacity = self._capacity(gates, self.capacity_factor * self.top_k, self.min_capacity) # update mask and locations by capacity - if drop_policy == "probs": + if self.drop_policy == "probs": capacity_probs, capacity_indices = paddle.topk(topk_masked_gates, k=capacity, axis=0, sorted=False) - capacity_mask = paddle.zeros_like(logits).put_along_axis( - capacity_indices, paddle.to_tensor(1.0), axis=0 - ) + capacity_mask = paddle.zeros_like(logits).put_along_axis(capacity_indices, paddle.to_tensor(1.0), axis=0) # fmt:skip mask = mask * capacity_mask locations = paddle.cumsum(mask, axis=0) - 1 - elif drop_policy == "position": + elif self.drop_policy == "position": locations = paddle.cumsum(mask, axis=0) - 1 mask *= (locations < capacity).cast(paddle.int64) else: - raise ValueError(f"Invalid drop_policy: {drop_policy}") + raise ValueError(f"Invalid drop_policy: {self.drop_policy}") else: # Do not drop tokens - set capacity according to current expert assignments @@ -458,19 +400,21 @@ def forward(self, hidden_states): raise NotImplementedError("Please implement the forward function.") -class TopKGate(BaseGate): +class TopKGate(PretrainedMoEGate): def __init__( self, num_experts, expert_hidden_size, weight_attr=None, bias_attr=None, - topk=2, + top_k=2, + capacity_factor=1.0, + eval_capacity_factor=1.0, scoring_func="softmax", scaling_attr=None, ): super().__init__(num_experts, expert_hidden_size, weight_attr, bias_attr) - self.topk = topk + self.top_k = top_k self.scoring_func = scoring_func self.scaling_attr = scaling_attr @@ -482,7 +426,7 @@ def forward( bsz, seq_len, hidden_size = hidden_states.shape hidden_states = hidden_states.reshape([-1, hidden_size]) logits = F.linear(x=paddle.cast(hidden_states, paddle.float32), weight=self.weight, bias=self.bias) - if self.topk == 1: + if self.top_k == 1: gate_output = self.top1gating( logits, self.capacity_factor if self.training else self.eval_capacity_factor, @@ -491,7 +435,7 @@ def forward( self.noisy_gate_policy if self.training else None, self.drop_tokens, ) - elif self.topk == 2: + elif self.top_k == 2: gate_output = self.top2gating( logits, self.capacity_factor if self.training else self.eval_capacity_factor, @@ -502,53 +446,10 @@ def forward( else: gate_output = self.topkgating( logits, - self.topk, + self.top_k, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity, self.drop_tokens, ) return gate_output - - -class GroupTopKGate(BaseGate): - def __init__( - self, - num_experts, - expert_hidden_size, - weight_attr=None, - 
bias_attr=None, - topk=2, - scoring_func="softmax", - scaling_attr=None, - n_group=1, - topk_group=1, - ): - super().__init__(num_experts, expert_hidden_size, weight_attr, bias_attr) - self.topk = topk - self.scoring_func = scoring_func - self.scaling_attr = scaling_attr - self.n_group = n_group - self.topk_group = topk_group - - def forward(self, hidden_states): - bsz, seq_len, hidden_size = hidden_states.shape - hidden_states = hidden_states.reshape([-1, hidden_size]) - scores = self.gate_score_func(hidden_states) - - group_scores = scores.reshape([bsz * seq_len, self.n_group, -1]).max(axis=-1).values # [n, n_group] - group_idx = paddle.topk(group_scores, k=self.topk_group, axis=-1, sorted=False)[1] # [n, top_k_group] - group_mask = paddle.zeros_like(group_scores).scatter_(1, group_idx, 1) # [n, n_group] - score_mask = ( - group_mask.unsqueeze(-1) - .expand(bsz * seq_len, self.n_group, self.num_experts // self.n_group) - .reshape(bsz * seq_len, -1) - ) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - - topk_weight, topk_idx = paddle.topk(tmp_scores, k=self.topk, axis=-1, largest=True, sorted=False) - - if self.scaling_attr is not None: - topk_weight = self.scaling_weight(topk_weight) - - return topk_weight, topk_idx, scores From 39d16606c7dd35bf32601dca96a05e70ec0c5f97 Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Wed, 16 Oct 2024 06:08:27 +0000 Subject: [PATCH 03/20] update --- paddlenlp/transformers/moe_gate.py | 83 ++++++++++++++++++--------- paddlenlp/transformers/moe_layer.py | 89 +++++++++++++++-------------- 2 files changed, 102 insertions(+), 70 deletions(-) diff --git a/paddlenlp/transformers/moe_gate.py b/paddlenlp/transformers/moe_gate.py index 3642a64af6f3..87549df17515 100644 --- a/paddlenlp/transformers/moe_gate.py +++ b/paddlenlp/transformers/moe_gate.py @@ -144,29 +144,13 @@ def __init__(self, num_experts, expert_hidden_size, **kwargs): self.num_experts = num_experts self.expert_hidden_size = expert_hidden_size + # force keep in float32 when using amp + self._cast_to_low_precision = False + self.capacity_factor = kwargs["capacity_factor"] if hasattr(kwargs, "capacity_factor") else 1.0 # fmt:skip self.eval_capacity_factor = kwargs["eval_capacity_factor"] if hasattr(kwargs, "eval_capacity_factor") else 1.0 # fmt:skip self.min_capacity = kwargs["min_capacity"] if hasattr(kwargs, "min_capacity") else 1.0 # fmt:skip - # force keep in float32 when using amp - self._cast_to_low_precision = False - self._weight_attr = kwargs["weight_attr"] if hasattr(kwargs, "weight_attr") else None # fmt:skip - self._bias_attr = kwargs["bias_attr"] if hasattr(kwargs, "bias_attr") else None # fmt:skip - - self.weight = paddle.create_parameter( - shape=[self.expert_hidden_size, self.num_experts], - attr=self._weight_attr, - dtype="float32", - is_bias=False, - ) - if self._bias_attr is not None and self._bias_attr: - self.bias = paddle.create_parameter( - shape=[self.num_experts], - attr=self._bias_attr, - dtype="float32", - is_bias=True, - ) - self.group = kwargs["group"] if hasattr(kwargs, "group") else None self.global_aux_loss = kwargs["global_aux_loss"] if hasattr(kwargs, "global_aux_loss") else False if self.global_aux_loss: @@ -183,6 +167,53 @@ def __init__(self, num_experts, expert_hidden_size, **kwargs): self.drop_policy = kwargs["drop_policy"] if hasattr(kwargs, "drop_policy") else "probs" self.top_k = kwargs["top_k"] if hasattr(kwargs, "top_k") else 1 + def topk_navie(self, scores: paddle.Tensor, k: int) -> Tuple[paddle.Tensor, paddle.Tensor]: + 
"""_summary_ + + Args: + scores (paddle.Tensor): [bsz*seq_len, n_experts] + k (int): select the top k experts + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx + topk_weight: [bsz*seq_len, k] + topk_idx: [bsz*seq_len, k] + """ + topk_weight, topk_idx = paddle.topk(scores, k=k, axis=-1, sorted=False) + return topk_weight, topk_idx + + def topk_group( + self, scores: paddle.Tensor, k: int, n_group: int, topk_group: int + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """_summary_ + + Args: + scores (paddle.Tensor): [bsz*seq_len, n_experts] + k (int): select the top k experts in each group + n_groups (int): the number of groups for all experts + topk_group (int): the number of groups selected + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx + topk_weight: [bsz*seq_len, k] + topk_idx: [bsz*seq_len, k] + + Note: the group size is normal greater than the number of k + """ + bsz_seq_len, n_experts = scores.shape + assert n_experts % n_group == 0, "n_experts must be divisible by n_groups" + + group_scores = scores.reshape([0, n_group, -1]).max(axis=-1) # [n, n_group] + group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=False)[1] # [n, top_k_group] + group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.to_tensor(1.0), axis=-1) # fmt:skip + score_mask = ( + group_mask.unsqueeze(-1).expand([bsz_seq_len, n_group, n_experts // n_group]).reshape([bsz_seq_len, -1]) + ) # [n, e] + tmp_scores = scores * score_mask # [n, e] + topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=False) + + return topk_weight, topk_idx + def top1gating( self, logits: paddle.Tensor, @@ -344,20 +375,16 @@ def top2gating( def topkgating( self, - logits: paddle.Tensor, + gates: paddle.Tensor, ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Implements TopKGating on logits.""" # get topk gates - top_gate, top_idx = paddle.topk(logits, k=self.top_k, axis=1) - # gating decisions - gates = self.gate_score_func(logits=logits) + top_gate, top_idx = paddle.topk(gates, k=self.top_k, axis=1) # get topk mask - topk_masked_gates = paddle.zeros_like(logits).put_along_axis(top_idx, top_gate, axis=1) mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1) exp_counts = paddle.sum(mask, axis=0) l_aux = self._cal_aux_loss(gates, mask) - l_zloss = self._cal_z_loss(logits) if self.drop_tokens: # Calculate configured capacity and remove locations outside capacity from mask @@ -365,8 +392,9 @@ def topkgating( # update mask and locations by capacity if self.drop_policy == "probs": + topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1) capacity_probs, capacity_indices = paddle.topk(topk_masked_gates, k=capacity, axis=0, sorted=False) - capacity_mask = paddle.zeros_like(logits).put_along_axis(capacity_indices, paddle.to_tensor(1.0), axis=0) # fmt:skip + capacity_mask = paddle.zeros_like(gates).put_along_axis(capacity_indices, paddle.to_tensor(1.0), axis=0) # fmt:skip mask = mask * capacity_mask locations = paddle.cumsum(mask, axis=0) - 1 @@ -375,7 +403,6 @@ def topkgating( mask *= (locations < capacity).cast(paddle.int64) else: raise ValueError(f"Invalid drop_policy: {self.drop_policy}") - else: # Do not drop tokens - set capacity according to current expert assignments new_capacity = paddle.max(exp_counts) @@ -394,7 +421,7 @@ def topkgating( combine_weights = paddle.einsum("se,sec->sec", gates_masked, locations_sc) dispatch_mask = 
combine_weights.cast(paddle.bool) - return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux def forward(self, hidden_states): raise NotImplementedError("Please implement the forward function.") diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index 7edd398e7b44..68656f6b38ba 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -1,4 +1,6 @@ # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -173,6 +175,7 @@ class MoELayer(nn.Layer): def __init__( self, gate: nn.Layer, + capacity: int, experts: List[nn.Layer], layer_idx, group: Group = None, @@ -182,9 +185,20 @@ def __init__( ): super().__init__() self.gate = gate - self.layer_idx = layer_idx - self.recompute = recompute - logger.info(f"using moe recompute={recompute}") + + self.num_experts = len(experts) + self.experts = experts + self.capacity = capacity + + self.group = group + self.all_to_all_dropout = all_to_all_dropout + + self.enable_recompute = False + + self.expert_parallel_degree = 1 if dist.get_world_size(self.group) < 1 else dist.get_world_size(group) + is_dummy_moe = dist.get_world_size(group) == 1 + self.rank = 0 if dist.get_rank(self.group) < 0 else dist.get_rank(self.group) + for p in self.gate.parameters(): p.is_gate = True if type(experts) == nn.LayerList: @@ -203,14 +217,10 @@ def __init__( p.no_sync = not is_dummy_moe # logger.info(f"expert param={p.name}, no-sync={p.no_sync}") - self.world_size = dist.get_world_size(group) - self.rank = dist.get_rank(self.group) - if self.world_size < 1: - self.world_size = 1 - if self.rank < 0: - self.rank = 0 - - self.num_local_experts = moe_num_experts // self.world_size + assert ( + self.num_experts // self.expert_parallel_degree == 0 + ), f"num_experts must be divisible by expert_parallel_degree, got: {self.num_experts} vs {self.expert_parallel_degree}" + self.num_local_experts = self.num_experts // self.expert_parallel_degree def forward(self, input): true_experts = self.experts[self.rank * self.num_local_experts : (self.rank + 1) * self.num_local_experts] @@ -238,39 +248,34 @@ def fwdfn(dispatched_input): expert_output = paddle.stack(expert_outputs, axis=1) # [ecm] return expert_output - assert self.gate is not None - if hasattr(self, "rng") and self.rng.random() < self.all_to_all_dropout: - orig_shape_2 = input.shape - input = input.reshape([self.world_size, self.num_local_experts, -1, input.shape[-1]]) - output = fwdfn(input) - output += self.gate.weight.sum() * 0.0 # hack for grad - output = output.reshape(orig_shape or orig_shape_2) # [e*1,c,m] - return output, None, 0 - - capacity, dispatch_mask, combine_weights, scatter_index, router_loss = self.gate(input) - self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1]) - - if self.world_size > 1: + # Initial implementation -> Reshape into S tokens by dropping sequence dimension. 
+ # Reshape into G groups so that each group can distribute tokens equally + # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1 + reshaped_input = input[0].reshape(-1, d_model) + self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input) + # self.l_aux : + # combine_weights : sec + # dispatch_mask : sec + # self.exp_counts : + dispatched_input = paddle.einsum("sec,sm->ecm", paddle.cast(dispatch_mask, input.dtype), reshaped_input) + + # capacity, dispatch_mask, combine_weights, scatter_index, router_loss = self.gate(input) + # self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1]) + + if self.expert_parallel_degree > 1: dispatched_input = _AllToAll.apply(dispatched_input, self.group) - dispatched_input = dispatched_input.reshape([self.world_size * self.num_local_experts, capacity, d_model]) - expert_output = ( - recompute(fwdfn, dispatched_input) if self.recompute and self.training else fwdfn(dispatched_input) + + # Re-shape after all-to-all: ecm -> gecm + dispatched_input = dispatched_input.reshape( + [self.expert_parallel_degree * self.num_local_experts, -1, d_model] ) - d_model_out = expert_output.shape[-1] + expert_output = fwdfn(dispatched_input) - if self.world_size > 1: + # Re-shape before drop_tokens: gecm -> ecm + expert_output = expert_output.reshape(self.expert_parallel_degree * self.num_local_experts, -1, d_model) + if self.expert_parallel_degree > 1: expert_output = _AllToAll.apply(expert_output, self.group) # 拿到不同device上的expert计算结果 + combined_output = paddle.einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output) - expert_output = expert_output.reshape( - [self.world_size * self.num_local_experts * capacity, d_model_out] - ) # [e * 1, c, m] - combined_output = combining(expert_output, combine_weights, scatter_index) - - if orig_shape: - combined_output = combined_output.reshape( - orig_shape[:-1] - + [ - d_model_out, - ] - ) - return combined_output, combine_weights, router_loss + a = combined_output.reshape(input[0].shape) + return a From cc41578db6b6742c799fce7ed6d94c33dca90a11 Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Mon, 21 Oct 2024 08:24:21 +0000 Subject: [PATCH 04/20] update base methods --- paddlenlp/transformers/__init__.py | 2 ++ paddlenlp/transformers/moe_layer.py | 27 ++++----------------------- 2 files changed, 6 insertions(+), 23 deletions(-) diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py index ab7510e0897e..32de7553992f 100644 --- a/paddlenlp/transformers/__init__.py +++ b/paddlenlp/transformers/__init__.py @@ -32,6 +32,8 @@ from .attention_utils import create_bigbird_rand_mask_idx_list from .sequence_parallel_utils import AllGatherVarlenOp, sequence_parallel_sparse_mask_labels from .tensor_parallel_utils import parallel_matmul, parallel_linear, fused_head_and_loss_fn +from .moe_gate import * +from .moe_layer import * try: from paddle.distributed.fleet.utils.sequence_parallel_utils import ( diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index 68656f6b38ba..521a6edd277a 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -14,30 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from collections import namedtuple -from contextlib import contextmanager -from typing import Any, List, Tuple +from typing import Any, Tuple import paddle import paddle.distributed as dist -import paddle.nn.functional as F from paddle import Tensor, nn -from paddle.distributed import fleet from paddle.distributed.communication import stream from paddle.distributed.communication.group import Group -from paddle.distributed.fleet.utils import recompute from ..utils.log import logger -GateOutput = namedtuple( - "GateOutput", - [ - "aux", - "z", - "logits", - ], -) - def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): """ @@ -176,12 +162,9 @@ def __init__( self, gate: nn.Layer, capacity: int, - experts: List[nn.Layer], - layer_idx, + experts: nn.LayerList, group: Group = None, - recompute=False, all_to_all_dropout=0.0, - moe_num_experts=2, ): super().__init__() self.gate = gate @@ -201,7 +184,8 @@ def __init__( for p in self.gate.parameters(): p.is_gate = True - if type(experts) == nn.LayerList: + + if isinstance(experts, nn.LayerList): self.experts = experts else: logger.info(f"using fused experts, type={type(experts)}") @@ -225,10 +209,7 @@ def __init__( def forward(self, input): true_experts = self.experts[self.rank * self.num_local_experts : (self.rank + 1) * self.num_local_experts] if input.ndim == 3: - orig_shape = input.shape reshaped_input = input.reshape([-1, input.shape[-1]]) - else: - orig_shape = None assert len(input.shape) == 2, f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}" # Implement Algorithm 2 from GShard paper. From 2b74f3052ad58df191460dfef3218d341dce8524 Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Thu, 24 Oct 2024 02:58:19 +0000 Subject: [PATCH 05/20] update moe_layer --- paddlenlp/transformers/moe_layer.py | 47 ++++++++++++++++------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index 521a6edd277a..94d57f110efd 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -202,18 +202,24 @@ def __init__( # logger.info(f"expert param={p.name}, no-sync={p.no_sync}") assert ( - self.num_experts // self.expert_parallel_degree == 0 + self.num_experts % self.expert_parallel_degree == 0 ), f"num_experts must be divisible by expert_parallel_degree, got: {self.num_experts} vs {self.expert_parallel_degree}" self.num_local_experts = self.num_experts // self.expert_parallel_degree - def forward(self, input): + def forward(self, hidden_state): + """_summary_ + + Args: + input (_type_): _description_ + + Returns: + _type_: _description_ + """ + true_experts = self.experts[self.rank * self.num_local_experts : (self.rank + 1) * self.num_local_experts] - if input.ndim == 3: - reshaped_input = input.reshape([-1, input.shape[-1]]) - assert len(input.shape) == 2, f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}" # Implement Algorithm 2 from GShard paper. - seqlen, d_model = input.shape + batch_size, seq_len, d_model = hidden_state.shape # Reshape into S tokens by dropping sequence dimension. # reshaped_input = input.reshape(-1, d_model) @@ -232,31 +238,30 @@ def fwdfn(dispatched_input): # Initial implementation -> Reshape into S tokens by dropping sequence dimension. 
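Editor's note: one fix in the commit above replaces the floor-division check with a modulo check. The earlier expression `num_experts // expert_parallel_degree == 0` is only true when there are fewer experts than ranks, whereas the intent is an even split. A quick comparison of the two expressions with hypothetical values:

num_experts, expert_parallel_degree = 8, 4

# Buggy check from the earlier commit: floor division, not remainder.
print(num_experts // expert_parallel_degree == 0)   # False -> the assert would always fire here
# Corrected check: experts split evenly across ranks.
print(num_experts % expert_parallel_degree == 0)    # True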
# Reshape into G groups so that each group can distribute tokens equally # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1 - reshaped_input = input[0].reshape(-1, d_model) - self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input) + reshaped_input = hidden_state.reshape([-1, d_model]) + + capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.gate(reshaped_input) + # self.l_aux, combine_weights, dispatch_mask, self.exp_counts = # self.l_aux : # combine_weights : sec # dispatch_mask : sec # self.exp_counts : - dispatched_input = paddle.einsum("sec,sm->ecm", paddle.cast(dispatch_mask, input.dtype), reshaped_input) - - # capacity, dispatch_mask, combine_weights, scatter_index, router_loss = self.gate(input) - # self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1]) + dispatched_input = paddle.einsum("sec,sm->ecm", paddle.cast(dispatch_mask, hidden_state.dtype), reshaped_input) if self.expert_parallel_degree > 1: dispatched_input = _AllToAll.apply(dispatched_input, self.group) # Re-shape after all-to-all: ecm -> gecm - dispatched_input = dispatched_input.reshape( - [self.expert_parallel_degree * self.num_local_experts, -1, d_model] - ) + dispatched_input = dispatched_input.reshape([self.expert_parallel_degree, self.num_local_experts, -1, d_model]) expert_output = fwdfn(dispatched_input) - # Re-shape before drop_tokens: gecm -> ecm - expert_output = expert_output.reshape(self.expert_parallel_degree * self.num_local_experts, -1, d_model) - if self.expert_parallel_degree > 1: - expert_output = _AllToAll.apply(expert_output, self.group) # 拿到不同device上的expert计算结果 - combined_output = paddle.einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output) + expert_output = expert_output.reshape([self.expert_parallel_degree * self.num_local_experts, -1, d_model]) + + expert_output = _AllToAll.apply(expert_output, self.group) + + # Re拿到不同device上的expert计算结果 + combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(hidden_state[0].dtype), expert_output) + + a = combined_output.reshape(hidden_state[0].shape) - a = combined_output.reshape(input[0].shape) return a From f5174730ff9c17b01564dd58fa7617a876756252 Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Thu, 24 Oct 2024 09:34:24 +0000 Subject: [PATCH 06/20] update moebase --- paddlenlp/transformers/moe_gate.py | 27 +++++++++--------- paddlenlp/transformers/moe_layer.py | 44 +++++++++++++++-------------- 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/paddlenlp/transformers/moe_gate.py b/paddlenlp/transformers/moe_gate.py index 87549df17515..f022fc07be88 100644 --- a/paddlenlp/transformers/moe_gate.py +++ b/paddlenlp/transformers/moe_gate.py @@ -147,25 +147,24 @@ def __init__(self, num_experts, expert_hidden_size, **kwargs): # force keep in float32 when using amp self._cast_to_low_precision = False - self.capacity_factor = kwargs["capacity_factor"] if hasattr(kwargs, "capacity_factor") else 1.0 # fmt:skip - self.eval_capacity_factor = kwargs["eval_capacity_factor"] if hasattr(kwargs, "eval_capacity_factor") else 1.0 # fmt:skip - self.min_capacity = kwargs["min_capacity"] if hasattr(kwargs, "min_capacity") else 1.0 # fmt:skip + self.capacity_factor = kwargs.pop("capacity_factor", 1.0) + self.eval_capacity_factor = kwargs.pop("eval_capacity_factor", 1.0) + self.min_capacity = kwargs.pop("min_capacity", 1.0) - self.group = kwargs["group"] if hasattr(kwargs, "group") else None - self.global_aux_loss = 
kwargs["global_aux_loss"] if hasattr(kwargs, "global_aux_loss") else False + self.group = kwargs.pop("group", None) + self.global_aux_loss = kwargs.pop("global_aux_loss", False) if self.global_aux_loss: assert self.group is not None, "group is required when global_aux_loss is True" self.rank = dist.get_rank(self.group) - self.expert_drop = kwargs["expert_drop"] if hasattr(kwargs, "expert_drop") else False - self.noisy_gate_policy = kwargs["noisy_gate_policy"] if hasattr(kwargs, "noisy_gate_policy") else None - self.drop_tokens = kwargs["drop_tokens"] if hasattr(kwargs, "drop_tokens") else True - self.use_rts = kwargs["use_rts"] if hasattr(kwargs, "use_rts") else True - self.top2_2nd_expert_sampling = ( - kwargs["top2_2nd_expert_sampling"] if hasattr(kwargs, "top2_2nd_expert_sampling") else True - ) - self.drop_policy = kwargs["drop_policy"] if hasattr(kwargs, "drop_policy") else "probs" - self.top_k = kwargs["top_k"] if hasattr(kwargs, "top_k") else 1 + self.expert_drop = kwargs.pop("expert_drop", False) + self.noisy_gate_policy = kwargs.pop("noisy_gate_policy", None) + self.drop_tokens = kwargs.pop("drop_tokens", True) + self.use_rts = kwargs.pop("use_rts", True) + self.top2_2nd_expert_sampling = kwargs.pop("top2_2nd_expert_sampling", True) + + self.drop_policy = kwargs.pop("drop_policy", "probs") + self.top_k = kwargs.pop("top_k", 2) def topk_navie(self, scores: paddle.Tensor, k: int) -> Tuple[paddle.Tensor, paddle.Tensor]: """_summary_ diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index 94d57f110efd..fa0ccdf12a55 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -206,41 +206,42 @@ def __init__( ), f"num_experts must be divisible by expert_parallel_degree, got: {self.num_experts} vs {self.expert_parallel_degree}" self.num_local_experts = self.num_experts // self.expert_parallel_degree - def forward(self, hidden_state): + def expert_forward(self, dispatched_input): + true_experts = self.experts[self.rank * self.num_local_experts : (self.rank + 1) * self.num_local_experts] + expert_outputs = [] + chunks = dispatched_input.unbind(1) + assert len(chunks) == len(true_experts), (len(chunks), len(true_experts)) + for chunk, expert in zip(chunks, true_experts): + chunk = chunk.contiguous() + expert_outputs += [expert(chunk)] + expert_output = paddle.stack(expert_outputs, axis=1) # [ecm] + return expert_output + + def forward( + self, + hidden_state: paddle.Tensor, + used_token: paddle.Tensor = None, + ): """_summary_ Args: input (_type_): _description_ + used_token Returns: _type_: _description_ """ - - true_experts = self.experts[self.rank * self.num_local_experts : (self.rank + 1) * self.num_local_experts] - # Implement Algorithm 2 from GShard paper. batch_size, seq_len, d_model = hidden_state.shape - # Reshape into S tokens by dropping sequence dimension. - # reshaped_input = input.reshape(-1, d_model) - # assert reshaped_input.shape[0] % len(self.experts) == 0, \ - # f'num tokens must be order of number of local experts, {input[0].shape[0]} vs {len(self.experts)}' - def fwdfn(dispatched_input): - expert_outputs = [] - chunks = dispatched_input.unbind(1) - assert len(chunks) == len(true_experts), (len(chunks), len(true_experts)) - for chunk, expert in zip(chunks, true_experts): - chunk = chunk.contiguous() - expert_outputs += [expert(chunk)] - expert_output = paddle.stack(expert_outputs, axis=1) # [ecm] - return expert_output - # Initial implementation -> Reshape into S tokens by dropping sequence dimension. 
# Reshape into G groups so that each group can distribute tokens equally # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1 reshaped_input = hidden_state.reshape([-1, d_model]) capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.gate(reshaped_input) + + print(f"capacity={capacity}") # self.l_aux, combine_weights, dispatch_mask, self.exp_counts = # self.l_aux : # combine_weights : sec @@ -253,15 +254,16 @@ def fwdfn(dispatched_input): # Re-shape after all-to-all: ecm -> gecm dispatched_input = dispatched_input.reshape([self.expert_parallel_degree, self.num_local_experts, -1, d_model]) - expert_output = fwdfn(dispatched_input) + expert_output = self.expert_forward(dispatched_input) # Re-shape before drop_tokens: gecm -> ecm expert_output = expert_output.reshape([self.expert_parallel_degree * self.num_local_experts, -1, d_model]) - expert_output = _AllToAll.apply(expert_output, self.group) + if self.expert_parallel_degree > 1: + expert_output = _AllToAll.apply(expert_output, self.group) # Re拿到不同device上的expert计算结果 combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(hidden_state[0].dtype), expert_output) - a = combined_output.reshape(hidden_state[0].shape) + a = combined_output.reshape(hidden_state.shape) return a From 1a3399ef3f2910fc31e769496d14bc73dba5e6e5 Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Thu, 24 Oct 2024 09:48:57 +0000 Subject: [PATCH 07/20] add moe_gate and moe_layer for qwen2moe --- paddlenlp/transformers/qwen2_moe/modeling.py | 95 ++++++++++---------- 1 file changed, 49 insertions(+), 46 deletions(-) diff --git a/paddlenlp/transformers/qwen2_moe/modeling.py b/paddlenlp/transformers/qwen2_moe/modeling.py index 46a3bb885a60..b5d84749fab8 100644 --- a/paddlenlp/transformers/qwen2_moe/modeling.py +++ b/paddlenlp/transformers/qwen2_moe/modeling.py @@ -34,6 +34,8 @@ from ..conversion_utils import StateDictNameMapping, init_name_mappings from ..model_outputs import MoECausalLMOutputWithPast, MoEModelOutputWithPast from ..model_utils import PretrainedModel, register_base_model +from ..moe_gate import PretrainedMoEGate +from ..moe_layer import MoELayer from .configuration import Qwen2MoeConfig try: @@ -683,68 +685,69 @@ def forward( return outputs -class Qwen2MoeSparseMoEBlock(nn.Layer): +class Qwen2MoeGate(PretrainedMoEGate): + def __init__(self, num_experts, expert_hidden_size, **kwargs): + super().__init__(num_experts, expert_hidden_size, **kwargs) + # [hidden_size, n_expert] + self.weight = paddle.create_parameter( + shape=[expert_hidden_size, num_experts], + dtype=paddle.get_default_dtype(), + is_bias=False, + default_initializer=nn.initializer.Constant(1.0), + ) + + def forward(self, hidden_states): + """ + Args: + hidden_states (_type_): [batch_size * seq_len, hidden_size] + """ + _, h_dim = hidden_states.shape + + # compute gating score + hidden_states = hidden_states.reshape([-1, h_dim]) + + with paddle.amp.auto_cast(False): + logits = F.linear(hidden_states.cast(paddle.float32), self.weight, None) + + scores = self.gate_score_func(logits=logits) + + # topk_weight, topk_idx = self.topk_navie(scores, k=2) + # topk_weight, topk_idx = self.topk_group(scores, k=2, n_group=4, topk_group=2) + + capacity, combine_weights, dispatch_mask, exp_counts, l_aux = self.topkgating(scores) + l_zloss = self._cal_z_loss(logits) + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + + +class Qwen2MoeSparseMoEBlock(MoELayer): def __init__(self, config: Qwen2MoeConfig): - super().__init__() + 
self.num_experts = config.num_experts self.top_k = config.num_experts_per_tok self.norm_topk_prob = config.norm_topk_prob - self.gate = nn.Linear(config.hidden_size, self.num_experts, bias_attr=False) + self.gate = Qwen2MoeGate(self.num_experts, config.hidden_size) + # self.gate = nn.Linear(config.hidden_size, self.num_experts, bias_attr=False)s self.experts = nn.LayerList([Qwen2MoeMLP(config) for _ in range(self.num_experts)]) + super().__init__( + gate=self.gate, + capacity=2.0, + experts=self.experts, + ) + self.shared_expert = Qwen2MoeMLP(config, is_shared=True) self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias_attr=False) def forward(self, hidden_states): - batch_size, seq_len, hidden_dim = hidden_states.shape - hidden_states = hidden_states.reshape([-1, hidden_dim]) - # router_logits: [batch_size * seq_len, num_experts] - router_logits = self.gate(hidden_states) - - with paddle.amp.auto_cast(False): - routing_weights = F.softmax(router_logits.astype("float32"), axis=1) - routing_weights, selected_experts = paddle.topk(routing_weights, self.top_k, axis=-1) - if self.norm_topk_prob: # Note: Mixtral is set norm as default, Qwen2Moe is set to no norm - routing_weights /= routing_weights.sum(axis=-1, keepdim=True) - # we cast back to input dtype - routing_weights = routing_weights.astype(hidden_states.dtype) - - final_hidden_states = paddle.zeros( - [batch_size * seq_len, hidden_dim], - dtype=hidden_states.dtype, - ) - - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated. - # shape: [num_experts, top_k, batch_size * seq_len] - expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).transpose([2, 1, 0]) - - # Loop over all available experts in the model and perform the computation on each expert. 
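Editor's note: the new Qwen2MoeGate computes router logits with one linear projection, converts them to probabilities in float32, and hands them to topkgating. A minimal sketch of the scoring and a naive top-k selection, with hypothetical sizes and capacity handling omitted:

import paddle
import paddle.nn.functional as F

num_tokens, hidden, num_experts, top_k = 5, 8, 4, 2   # hypothetical sizes
hidden_states = paddle.randn([num_tokens, hidden])
gate_weight = paddle.randn([hidden, num_experts])

logits = F.linear(hidden_states, gate_weight)            # [tokens, experts]
scores = F.softmax(logits.astype("float32"), axis=-1)    # routing probabilities
topk_weight, topk_idx = paddle.topk(scores, k=top_k, axis=-1)

# Optional renormalization of the selected probabilities (norm_topk_prob):
topk_weight = topk_weight / topk_weight.sum(axis=-1, keepdim=True)
print(topk_idx.shape, topk_weight.shape)                 # [5, 2] [5, 2]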
- for expert_id in range(self.num_experts): - expert_layer = self.experts[expert_id] - idx, top_x = paddle.where(expert_mask[expert_id]) - - if top_x.shape[0] == 0: - continue - - current_state = paddle.gather(hidden_states, top_x.squeeze()) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx] - - top_x = top_x.squeeze() - if top_x.shape == []: - top_x = paddle.to_tensor([top_x.item()]) - final_hidden_states = paddle.index_add_( - final_hidden_states, top_x, 0, current_hidden_states.astype(hidden_states.dtype) - ) + final_hidden_states = super().forward(hidden_states) shared_expert_output = self.shared_expert(hidden_states) shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output - final_hidden_states = final_hidden_states + shared_expert_output - final_hidden_states = final_hidden_states.reshape([batch_size, seq_len, hidden_dim]) - return final_hidden_states, router_logits + return final_hidden_states, 0.0 class Qwen2MoeDecoderLayer(nn.Layer): From d6a16eb49b7069deb7d336678ce74a575b546ff0 Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Wed, 30 Oct 2024 02:28:24 +0000 Subject: [PATCH 08/20] add config --- llm/config/qwen2moe/lora_argument.json | 34 ++++++++++++++++++ llm/config/qwen2moe/pretrain_argument.json | 40 ++++++++++++++++++++++ llm/config/qwen2moe/sft_argument.json | 33 ++++++++++++++++++ 3 files changed, 107 insertions(+) create mode 100644 llm/config/qwen2moe/lora_argument.json create mode 100644 llm/config/qwen2moe/pretrain_argument.json create mode 100644 llm/config/qwen2moe/sft_argument.json diff --git a/llm/config/qwen2moe/lora_argument.json b/llm/config/qwen2moe/lora_argument.json new file mode 100644 index 000000000000..47e7adb14ecd --- /dev/null +++ b/llm/config/qwen2moe/lora_argument.json @@ -0,0 +1,34 @@ +{ + "model_name_or_path": "Qwen/Qwen2-57B-A14B", + "dataset_name_or_path": "./data", + "output_dir": "./checkpoints/lora_ckpts", + "per_device_train_batch_size": 4, + "gradient_accumulation_steps": 4, + "per_device_eval_batch_size": 8, + "eval_accumulation_steps":16, + "num_train_epochs": 3, + "learning_rate": 3e-04, + "warmup_steps": 30, + "logging_steps": 1, + "evaluation_strategy": "epoch", + "save_strategy": "epoch", + "src_length": 1024, + "max_length": 2048, + "bf16": true, + "fp16_opt_level": "O2", + "do_train": true, + "do_eval": true, + "disable_tqdm": true, + "load_best_model_at_end": true, + "eval_with_do_generation": false, + "metric_for_best_model": "accuracy", + "recompute": true, + "save_total_limit": 1, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 1, + "lora": true, + "unified_checkpoint": true, + "zero_padding": false, + "use_flash_attention": true, + "pissa": false + } diff --git a/llm/config/qwen2moe/pretrain_argument.json b/llm/config/qwen2moe/pretrain_argument.json new file mode 100644 index 000000000000..f3115a64b648 --- /dev/null +++ b/llm/config/qwen2moe/pretrain_argument.json @@ -0,0 +1,40 @@ +{ + "model_name_or_path": "Qwen/Qwen2-57B-A14B", + "tokenizer_name_or_path": "Qwen/Qwen2-57B-A14B", + "input_dir": "./data", + "output_dir": "./checkpoints/pretrain_ckpts", + "per_device_train_batch_size": 2, + "gradient_accumulation_steps": 1, + "per_device_eval_batch_size": 2, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 1, + "sharding": "stage2", + "virtual_pp_degree": 1, + "sequence_parallel": 0, + "use_flash_attention": true, + "use_fused_rms_norm": true, + "max_seq_length": 4096, + "learning_rate": 3e-05, + "min_learning_rate": 3e-06, + 
"warmup_steps": 30, + "logging_steps": 1, + "max_steps": 10000, + "save_steps": 5000, + "eval_steps": 1000, + "weight_decay": 0.01, + "bf16": true, + "fp16_opt_level": "O2", + "warmup_ratio": 0.01, + "max_grad_norm": 1.0, + "dataloader_num_workers": 1, + "continue_training": 1, + "do_train": true, + "do_eval": true, + "do_predict": true, + "disable_tqdm": true, + "recompute": true, + "distributed_dataloader": 1, + "recompute_granularity": "full", + "unified_checkpoint": true, + "save_total_limit": 2 + } diff --git a/llm/config/qwen2moe/sft_argument.json b/llm/config/qwen2moe/sft_argument.json new file mode 100644 index 000000000000..c964137f2264 --- /dev/null +++ b/llm/config/qwen2moe/sft_argument.json @@ -0,0 +1,33 @@ +{ + "model_name_or_path": "Qwen/Qwen2-57B-A14B", + "dataset_name_or_path": "./data", + "output_dir": "./checkpoints/sft_ckpts", + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 4, + "per_device_eval_batch_size": 8, + "eval_accumulation_steps":16, + "num_train_epochs": 3, + "learning_rate": 3e-05, + "warmup_steps": 30, + "logging_steps": 1, + "evaluation_strategy": "epoch", + "save_strategy": "epoch", + "src_length": 1024, + "max_length": 2048, + "bf16": true, + "fp16_opt_level": "O2", + "do_train": true, + "do_eval": true, + "disable_tqdm": true, + "load_best_model_at_end": true, + "eval_with_do_generation": false, + "metric_for_best_model": "accuracy", + "recompute": true, + "save_total_limit": 1, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 1, + "sharding": "stage2", + "zero_padding": false, + "unified_checkpoint": true, + "use_flash_attention": true + } From fad1a4f71288aa136a0e40046032e6a11ea8d654 Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Thu, 31 Oct 2024 01:53:13 +0000 Subject: [PATCH 09/20] update --- paddlenlp/transformers/moe_layer.py | 37 +++++++------------- paddlenlp/transformers/qwen2_moe/modeling.py | 10 ++---- 2 files changed, 15 insertions(+), 32 deletions(-) diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index fa0ccdf12a55..0f82e9ec1f02 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -160,17 +160,14 @@ class MoELayer(nn.Layer): def __init__( self, - gate: nn.Layer, - capacity: int, - experts: nn.LayerList, + num_experts: int, + capacity: int = 1.0, group: Group = None, all_to_all_dropout=0.0, ): super().__init__() - self.gate = gate - self.num_experts = len(experts) - self.experts = experts + self.num_experts = num_experts self.capacity = capacity self.group = group @@ -179,33 +176,25 @@ def __init__( self.enable_recompute = False self.expert_parallel_degree = 1 if dist.get_world_size(self.group) < 1 else dist.get_world_size(group) - is_dummy_moe = dist.get_world_size(group) == 1 + self.is_dummy_moe = dist.get_world_size(self.group) == 1 self.rank = 0 if dist.get_rank(self.group) < 0 else dist.get_rank(self.group) + assert ( + self.num_experts % self.expert_parallel_degree == 0 + ), f"num_experts must be divisible by expert_parallel_degree, got: {self.num_experts} vs {self.expert_parallel_degree}" + self.num_local_experts = self.num_experts // self.expert_parallel_degree + + def _post_init(self): for p in self.gate.parameters(): p.is_gate = True - if isinstance(experts, nn.LayerList): - self.experts = experts - else: - logger.info(f"using fused experts, type={type(experts)}") - self.experts = nn.LayerList([experts]) - self.group = group - self.all_to_all_dropout = all_to_all_dropout - is_dummy_moe = dist.get_world_size(group) == 1 
or dist.get_world_size(group) == -1 - - for k in experts: + for k in self.experts: if k is not None: for p in k.parameters(): - p.expert = not is_dummy_moe - p.no_sync = not is_dummy_moe + p.expert = not self.is_dummy_moe + p.no_sync = not self.is_dummy_moe # logger.info(f"expert param={p.name}, no-sync={p.no_sync}") - assert ( - self.num_experts % self.expert_parallel_degree == 0 - ), f"num_experts must be divisible by expert_parallel_degree, got: {self.num_experts} vs {self.expert_parallel_degree}" - self.num_local_experts = self.num_experts // self.expert_parallel_degree - def expert_forward(self, dispatched_input): true_experts = self.experts[self.rank * self.num_local_experts : (self.rank + 1) * self.num_local_experts] expert_outputs = [] diff --git a/paddlenlp/transformers/qwen2_moe/modeling.py b/paddlenlp/transformers/qwen2_moe/modeling.py index b5d84749fab8..edd6b568e785 100644 --- a/paddlenlp/transformers/qwen2_moe/modeling.py +++ b/paddlenlp/transformers/qwen2_moe/modeling.py @@ -722,21 +722,15 @@ def forward(self, hidden_states): class Qwen2MoeSparseMoEBlock(MoELayer): def __init__(self, config: Qwen2MoeConfig): + super().__init__(num_experts=config.num_experts, capacity=2.0) - self.num_experts = config.num_experts self.top_k = config.num_experts_per_tok self.norm_topk_prob = config.norm_topk_prob self.gate = Qwen2MoeGate(self.num_experts, config.hidden_size) - # self.gate = nn.Linear(config.hidden_size, self.num_experts, bias_attr=False)s + # self.gate = nn.Linear(config.hidden_size, self.num_experts, bias_attr=False) self.experts = nn.LayerList([Qwen2MoeMLP(config) for _ in range(self.num_experts)]) - super().__init__( - gate=self.gate, - capacity=2.0, - experts=self.experts, - ) - self.shared_expert = Qwen2MoeMLP(config, is_shared=True) self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias_attr=False) From 8701b5235380a837c3af730ec1fb34bda896ac34 Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Mon, 4 Nov 2024 02:35:24 +0000 Subject: [PATCH 10/20] update gate dtype --- paddlenlp/transformers/qwen2_moe/modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/transformers/qwen2_moe/modeling.py b/paddlenlp/transformers/qwen2_moe/modeling.py index edd6b568e785..0695f9f607d2 100644 --- a/paddlenlp/transformers/qwen2_moe/modeling.py +++ b/paddlenlp/transformers/qwen2_moe/modeling.py @@ -691,7 +691,7 @@ def __init__(self, num_experts, expert_hidden_size, **kwargs): # [hidden_size, n_expert] self.weight = paddle.create_parameter( shape=[expert_hidden_size, num_experts], - dtype=paddle.get_default_dtype(), + dtype=paddle.float32, is_bias=False, default_initializer=nn.initializer.Constant(1.0), ) From 448ecbdd6c42344f73a0bdd94b99c2dd453a727d Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Mon, 4 Nov 2024 09:29:26 +0000 Subject: [PATCH 11/20] update moe gate and layer --- paddlenlp/transformers/moe_gate.py | 4 +- paddlenlp/transformers/moe_layer.py | 68 ++++++++++++++++---- paddlenlp/transformers/qwen2_moe/modeling.py | 15 +++-- 3 files changed, 68 insertions(+), 19 deletions(-) diff --git a/paddlenlp/transformers/moe_gate.py b/paddlenlp/transformers/moe_gate.py index f022fc07be88..e6103a617bce 100644 --- a/paddlenlp/transformers/moe_gate.py +++ b/paddlenlp/transformers/moe_gate.py @@ -41,7 +41,9 @@ def gate_score_func(self, logits: paddle.Tensor) -> paddle.Tensor: elif scoring_func == "leaky_relu": scores = F.leaky_relu(logits.cast("float32")) else: - logger.warning(f"insupportable scoring function for MoE gating: {scoring_func}, use softmax 
instead") + logger.warning_once( + f"insupportable scoring function for MoE gating: {scoring_func}, use softmax instead" + ) scores = F.softmax(logits.cast("float32"), axis=-1) return scores diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index 0f82e9ec1f02..aa434539521d 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -14,7 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Tuple +from copy import deepcopy +from typing import Any, Tuple, Union import paddle import paddle.distributed as dist @@ -160,29 +161,72 @@ class MoELayer(nn.Layer): def __init__( self, - num_experts: int, + config, + moe_num_experts: int, + expert_class: nn.Layer, + expert_kwargs: dict, + gate: nn.Layer, capacity: int = 1.0, - group: Group = None, + moe_group: str = "data", all_to_all_dropout=0.0, ): super().__init__() - self.num_experts = num_experts + self.config = config + + self.moe_num_experts = moe_num_experts self.capacity = capacity - self.group = group - self.all_to_all_dropout = all_to_all_dropout + self.moe_group = self._parse_moe_group(moe_group) # moe_group is str + self.moe_rank = dist.get_rank(self.moe_group) + self.moe_rank = 0 if self.moe_rank < 0 else self.moe_rank + self.expert_parallel_degree = dist.get_world_size(self.moe_group) + self.expert_parallel_degree = 1 if self.expert_parallel_degree < 0 else self.expert_parallel_degree + self.moe_num_experts_per_device = self._parse_moe_expert_parallel( + config, self.moe_num_experts, self.expert_parallel_degree + ) + self.all_to_all_dropout = all_to_all_dropout self.enable_recompute = False - self.expert_parallel_degree = 1 if dist.get_world_size(self.group) < 1 else dist.get_world_size(group) - self.is_dummy_moe = dist.get_world_size(self.group) == 1 - self.rank = 0 if dist.get_rank(self.group) < 0 else dist.get_rank(self.group) + self.experts = nn.LayerList([]) + expert = expert_class(**expert_kwargs) + for i in range(self.moe_num_experts): + if i // self.moe_world_size_per_device == self.moe_rank: + self.experts.append(deepcopy(expert)) + else: + self.experts.append(None) + self.gate = gate + + def _parse_moe_group( + self, + moe_group: str = "data", + ) -> Union[str, paddle.distributed.communication.group.Group]: + moe_group = moe_group.lower() + assert moe_group in {"data", "dp", "dummy"}, f"moe-group not supported, got: {moe_group}" + logger.info(f"using moe-group: {moe_group}") + if not hasattr(dist.fleet.fleet, "_hcg"): + assert moe_group in {"dummy"}, "only support dummy gate in `single-model`" + if moe_group in {"data", "dp"}: + moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + elif moe_group in {"dummy"}: + dummy_group = dist.communication.group.Group(0, None, [0]) + moe_group = dummy_group + else: + moe_group = dist.communication.group._get_global_group() # None 为全局通信组 + + return moe_group + + def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree): + assert ( + moe_num_experts >= expert_parallel_degree + ), f"expert moe_num_experts={moe_num_experts} >= moe_world_size={expert_parallel_degree}" assert ( - self.num_experts % self.expert_parallel_degree == 0 - ), f"num_experts must be divisible by expert_parallel_degree, got: {self.num_experts} vs {self.expert_parallel_degree}" - self.num_local_experts = self.num_experts // self.expert_parallel_degree + moe_num_experts % expert_parallel_degree == 0 + ), f"expert 
moe_num_experts={moe_num_experts} % moe_world_size={expert_parallel_degree} == 0" + moe_world_size_per_device = moe_num_experts // expert_parallel_degree + return moe_world_size_per_device def _post_init(self): for p in self.gate.parameters(): diff --git a/paddlenlp/transformers/qwen2_moe/modeling.py b/paddlenlp/transformers/qwen2_moe/modeling.py index 0695f9f607d2..36f71ceab2f9 100644 --- a/paddlenlp/transformers/qwen2_moe/modeling.py +++ b/paddlenlp/transformers/qwen2_moe/modeling.py @@ -54,7 +54,7 @@ try: from paddle.nn.functional.flash_attention import flash_attention -except: +except ImportError: flash_attention = None __all__ = [ @@ -722,15 +722,18 @@ def forward(self, hidden_states): class Qwen2MoeSparseMoEBlock(MoELayer): def __init__(self, config: Qwen2MoeConfig): - super().__init__(num_experts=config.num_experts, capacity=2.0) + super().__init__( + config, + moe_num_experts=config.num_experts, + expert_class=Qwen2MoeMLP, + expert_kwargs=config, + gate=Qwen2MoeGate(config.num_experts, config.hidden_size), + capacity=2.0, + ) self.top_k = config.num_experts_per_tok self.norm_topk_prob = config.norm_topk_prob - self.gate = Qwen2MoeGate(self.num_experts, config.hidden_size) - # self.gate = nn.Linear(config.hidden_size, self.num_experts, bias_attr=False) - self.experts = nn.LayerList([Qwen2MoeMLP(config) for _ in range(self.num_experts)]) - self.shared_expert = Qwen2MoeMLP(config, is_shared=True) self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias_attr=False) From 77ec9b0a6a93133cddb09f377952930d491e8740 Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Tue, 5 Nov 2024 03:28:29 +0000 Subject: [PATCH 12/20] update moe_layer.py --- paddlenlp/transformers/moe_layer.py | 71 +++++++++++++---------------- 1 file changed, 31 insertions(+), 40 deletions(-) diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index aa434539521d..96ed534fe99c 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -15,7 +15,7 @@ # limitations under the License. 
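Editor's note: this commit materializes only the experts owned by the local rank and stores None placeholders for the rest of the nn.LayerList, so indices stay global while remote experts consume no memory. The ownership rule is `i // moe_num_experts_per_device == moe_rank`. A sketch of that rule with stand-in Linear experts (a plain list is used here; the patch applies the same pattern to nn.LayerList):

import paddle.nn as nn

moe_num_experts, expert_parallel_degree, moe_rank = 8, 4, 1   # hypothetical setup
moe_num_experts_per_device = moe_num_experts // expert_parallel_degree  # 2 experts per rank

experts = []
for i in range(moe_num_experts):
    if i // moe_num_experts_per_device == moe_rank:   # expert i lives on this rank
        experts.append(nn.Linear(16, 16))             # stand-in for the real expert MLP
    else:
        experts.append(None)                          # placeholder keeps global indexing

print([i for i, e in enumerate(experts) if e is not None])    # [2, 3]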
from copy import deepcopy -from typing import Any, Tuple, Union +from typing import Any, Tuple import paddle import paddle.distributed as dist @@ -23,8 +23,6 @@ from paddle.distributed.communication import stream from paddle.distributed.communication.group import Group -from ..utils.log import logger - def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): """ @@ -177,47 +175,34 @@ def __init__( self.moe_num_experts = moe_num_experts self.capacity = capacity - self.moe_group = self._parse_moe_group(moe_group) # moe_group is str - self.moe_rank = dist.get_rank(self.moe_group) - self.moe_rank = 0 if self.moe_rank < 0 else self.moe_rank - self.expert_parallel_degree = dist.get_world_size(self.moe_group) - self.expert_parallel_degree = 1 if self.expert_parallel_degree < 0 else self.expert_parallel_degree - self.moe_num_experts_per_device = self._parse_moe_expert_parallel( - config, self.moe_num_experts, self.expert_parallel_degree - ) + if dist.get_world_size() > 1: + self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + self.moe_rank = dist.get_rank(self.moe_group) + self.moe_rank = 0 if self.moe_rank < 0 else self.moe_rank + self.expert_parallel_degree = dist.get_world_size(self.moe_group) + self.expert_parallel_degree = 1 if self.expert_parallel_degree < 0 else self.expert_parallel_degree + self.moe_num_experts_per_device = self._parse_moe_expert_parallel( + self.moe_num_experts, self.expert_parallel_degree + ) + else: + self.moe_group = None + self.moe_rank = 0 + self.expert_parallel_degree = 1 + self.moe_num_experts_per_device = self.moe_num_experts self.all_to_all_dropout = all_to_all_dropout self.enable_recompute = False self.experts = nn.LayerList([]) - expert = expert_class(**expert_kwargs) + expert = expert_class(expert_kwargs) for i in range(self.moe_num_experts): - if i // self.moe_world_size_per_device == self.moe_rank: + if i // self.moe_num_experts_per_device == self.moe_rank: self.experts.append(deepcopy(expert)) else: self.experts.append(None) self.gate = gate - def _parse_moe_group( - self, - moe_group: str = "data", - ) -> Union[str, paddle.distributed.communication.group.Group]: - moe_group = moe_group.lower() - assert moe_group in {"data", "dp", "dummy"}, f"moe-group not supported, got: {moe_group}" - logger.info(f"using moe-group: {moe_group}") - if not hasattr(dist.fleet.fleet, "_hcg"): - assert moe_group in {"dummy"}, "only support dummy gate in `single-model`" - if moe_group in {"data", "dp"}: - moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() - elif moe_group in {"dummy"}: - dummy_group = dist.communication.group.Group(0, None, [0]) - moe_group = dummy_group - else: - moe_group = dist.communication.group._get_global_group() # None 为全局通信组 - - return moe_group - def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree): assert ( moe_num_experts >= expert_parallel_degree @@ -225,8 +210,8 @@ def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree): assert ( moe_num_experts % expert_parallel_degree == 0 ), f"expert moe_num_experts={moe_num_experts} % moe_world_size={expert_parallel_degree} == 0" - moe_world_size_per_device = moe_num_experts // expert_parallel_degree - return moe_world_size_per_device + moe_num_experts_per_device = moe_num_experts // expert_parallel_degree + return moe_num_experts_per_device def _post_init(self): for p in self.gate.parameters(): @@ -240,7 +225,9 @@ def _post_init(self): # logger.info(f"expert param={p.name}, 
no-sync={p.no_sync}") def expert_forward(self, dispatched_input): - true_experts = self.experts[self.rank * self.num_local_experts : (self.rank + 1) * self.num_local_experts] + true_experts = self.experts[ + self.moe_rank * self.moe_num_experts_per_device : (self.moe_rank + 1) * self.moe_num_experts_per_device + ] expert_outputs = [] chunks = dispatched_input.unbind(1) assert len(chunks) == len(true_experts), (len(chunks), len(true_experts)) @@ -274,7 +261,7 @@ def forward( capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.gate(reshaped_input) - print(f"capacity={capacity}") + # print(f"capacity={capacity}") # self.l_aux, combine_weights, dispatch_mask, self.exp_counts = # self.l_aux : # combine_weights : sec @@ -283,16 +270,20 @@ def forward( dispatched_input = paddle.einsum("sec,sm->ecm", paddle.cast(dispatch_mask, hidden_state.dtype), reshaped_input) if self.expert_parallel_degree > 1: - dispatched_input = _AllToAll.apply(dispatched_input, self.group) + dispatched_input = _AllToAll.apply(dispatched_input, self.moe_group) # Re-shape after all-to-all: ecm -> gecm - dispatched_input = dispatched_input.reshape([self.expert_parallel_degree, self.num_local_experts, -1, d_model]) + dispatched_input = dispatched_input.reshape( + [self.expert_parallel_degree, self.moe_num_experts_per_device, -1, d_model] + ) expert_output = self.expert_forward(dispatched_input) # Re-shape before drop_tokens: gecm -> ecm - expert_output = expert_output.reshape([self.expert_parallel_degree * self.num_local_experts, -1, d_model]) + expert_output = expert_output.reshape( + [self.expert_parallel_degree * self.moe_num_experts_per_device, -1, d_model] + ) if self.expert_parallel_degree > 1: - expert_output = _AllToAll.apply(expert_output, self.group) + expert_output = _AllToAll.apply(expert_output, self.moe_group) # Re拿到不同device上的expert计算结果 combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(hidden_state[0].dtype), expert_output) From ff930128d973bbdcfd7db3cfa2b4d865cf9c8035 Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Wed, 6 Nov 2024 02:12:38 +0000 Subject: [PATCH 13/20] update --- paddlenlp/transformers/moe_gate.py | 6 +++++- paddlenlp/transformers/qwen2_moe/modeling.py | 16 +++++++++++----- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/paddlenlp/transformers/moe_gate.py b/paddlenlp/transformers/moe_gate.py index e6103a617bce..0ba5239dda6b 100644 --- a/paddlenlp/transformers/moe_gate.py +++ b/paddlenlp/transformers/moe_gate.py @@ -167,6 +167,7 @@ def __init__(self, num_experts, expert_hidden_size, **kwargs): self.drop_policy = kwargs.pop("drop_policy", "probs") self.top_k = kwargs.pop("top_k", 2) + self.norm_topk_prob = kwargs.pop("norm_topk_prob", False) def topk_navie(self, scores: paddle.Tensor, k: int) -> Tuple[paddle.Tensor, paddle.Tensor]: """_summary_ @@ -406,6 +407,8 @@ def topkgating( raise ValueError(f"Invalid drop_policy: {self.drop_policy}") else: # Do not drop tokens - set capacity according to current expert assignments + locations = paddle.cumsum(mask, axis=0) - 1 + new_capacity = paddle.max(exp_counts) if self.group is not None: dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=self.group) @@ -415,7 +418,8 @@ def topkgating( gates_masked = gates * mask gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True) denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps) - gates_masked = gates_masked / denom_s + if self.norm_topk_prob: + gates_masked = gates_masked / denom_s # dispatch_mask locations_sc = 
self._one_hot_to_float(locations * mask, num_classes=capacity) diff --git a/paddlenlp/transformers/qwen2_moe/modeling.py b/paddlenlp/transformers/qwen2_moe/modeling.py index 36f71ceab2f9..5607af83dce0 100644 --- a/paddlenlp/transformers/qwen2_moe/modeling.py +++ b/paddlenlp/transformers/qwen2_moe/modeling.py @@ -691,7 +691,7 @@ def __init__(self, num_experts, expert_hidden_size, **kwargs): # [hidden_size, n_expert] self.weight = paddle.create_parameter( shape=[expert_hidden_size, num_experts], - dtype=paddle.float32, + dtype=paddle.get_default_dtype(), is_bias=False, default_initializer=nn.initializer.Constant(1.0), ) @@ -705,11 +705,10 @@ def forward(self, hidden_states): # compute gating score hidden_states = hidden_states.reshape([-1, h_dim]) + logits = F.linear(hidden_states, self.weight, None) with paddle.amp.auto_cast(False): - logits = F.linear(hidden_states.cast(paddle.float32), self.weight, None) - - scores = self.gate_score_func(logits=logits) + scores = self.gate_score_func(logits=logits.cast(paddle.float32)) # topk_weight, topk_idx = self.topk_navie(scores, k=2) # topk_weight, topk_idx = self.topk_group(scores, k=2, n_group=4, topk_group=2) @@ -722,12 +721,19 @@ def forward(self, hidden_states): class Qwen2MoeSparseMoEBlock(MoELayer): def __init__(self, config: Qwen2MoeConfig): + gate = Qwen2MoeGate( + config.num_experts, + config.hidden_size, + top_k=config.num_experts_per_tok, + drop_tokens=False, + ) + super().__init__( config, moe_num_experts=config.num_experts, expert_class=Qwen2MoeMLP, expert_kwargs=config, - gate=Qwen2MoeGate(config.num_experts, config.hidden_size), + gate=gate, capacity=2.0, ) From de2d2576719e994d1a5ae824f739718a0d393113 Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Thu, 7 Nov 2024 14:59:49 +0800 Subject: [PATCH 14/20] update --- paddlenlp/trainer/training_args.py | 8 ++ paddlenlp/transformers/moe_gate.py | 99 +++++--------------- paddlenlp/transformers/moe_layer.py | 6 +- paddlenlp/transformers/qwen2_moe/modeling.py | 89 +++++++++++++++--- 4 files changed, 110 insertions(+), 92 deletions(-) diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 7cdca6572c4f..693a9d1a6993 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -882,6 +882,14 @@ class TrainingArguments: default=False, metadata={"help": "Enable MoE (Mixture of Experts) expert parallel training"}, ) + expert_max_capacity: Optional[int] = field( + default=pow(2, 32), + metadata={"help": "Enable MoE (Mixture of Experts) expert max token capacity"}, + ) + expert_min_capacity: Optional[int] = field( + default=1, + metadata={"help": "Enable MoE (Mixture of Experts) expert min token capacity"}, + ) release_grads: Optional[bool] = field( default=False, metadata={"help": "Whether to release gradients during training. 
Default is `False`."} ) diff --git a/paddlenlp/transformers/moe_gate.py b/paddlenlp/transformers/moe_gate.py index 0ba5239dda6b..64a2ecf55e9a 100644 --- a/paddlenlp/transformers/moe_gate.py +++ b/paddlenlp/transformers/moe_gate.py @@ -59,7 +59,7 @@ def uniform_sample(self, logits: paddle.Tensor) -> paddle.Tensor: def _one_hot_to_float(self, x, num_classes): if x.dtype not in (paddle.int32, paddle.int64): x = paddle.cast(x, paddle.int64) - return F.one_hot(x, num_classes=num_classes).cast(paddle.float32) + return F.one_hot(x, num_classes=num_classes).cast(paddle.get_default_dtype()) @paddle.no_grad() def _one_hot_to_int64(self, x, num_classes): @@ -68,7 +68,9 @@ def _one_hot_to_int64(self, x, num_classes): return F.one_hot(x, num_classes=num_classes).cast(paddle.int64) @paddle.no_grad() - def _capacity(self, gates: paddle.Tensor, capacity_factor: float, min_capacity: int) -> paddle.Tensor: + def _capacity( + self, gates: paddle.Tensor, capacity_factor: float, max_capacity: int, min_capacity: int + ) -> paddle.Tensor: """Calculate the capacity for each expert based on the gates and capacity factor. Args: @@ -87,6 +89,8 @@ def _capacity(self, gates: paddle.Tensor, capacity_factor: float, min_capacity: capacity = int((num_tokens // num_experts) * capacity_factor) if capacity < min_capacity: capacity = min_capacity + if capacity > max_capacity: + capacity = max_capacity assert capacity > 0, f"requires capacity > 0, capacity_factor: {capacity_factor}, input_shape: {gates.shape}" return capacity @@ -96,8 +100,8 @@ def _cal_aux_loss(self, gates, mask): 计算辅助损失 Args: - gates (paddle.Tensor): 表示每个expert的输出概率。形状为[batch_size,num_experts] - mask (paddle.Tensor): 表示每个样本是否属于某个expert。形状为[batch_size,num_experts] + gates (paddle.Tensor): 表示每个expert的输出概率。形状为[batch_size, num_experts] + mask (paddle.Tensor): 表示每个样本是否属于某个expert。形状为[batch_size, num_experts] Returns: paddle.Tensor: 辅助损失值。 @@ -140,9 +144,11 @@ def _cal_orthogonal_loss(self) -> paddle.Tensor: class PretrainedMoEGate(nn.Layer, MoEGateMixin): - def __init__(self, num_experts, expert_hidden_size, **kwargs): + def __init__(self, config, num_experts, expert_hidden_size, **kwargs): super(PretrainedMoEGate, self).__init__() + self.config = config + self.num_experts = num_experts self.expert_hidden_size = expert_hidden_size @@ -152,6 +158,7 @@ def __init__(self, num_experts, expert_hidden_size, **kwargs): self.capacity_factor = kwargs.pop("capacity_factor", 1.0) self.eval_capacity_factor = kwargs.pop("eval_capacity_factor", 1.0) self.min_capacity = kwargs.pop("min_capacity", 1.0) + self.max_capacity = kwargs.pop("max_capacity", pow(2, 32)) self.group = kwargs.pop("group", None) self.global_aux_loss = kwargs.pop("global_aux_loss", False) @@ -169,7 +176,7 @@ def __init__(self, num_experts, expert_hidden_size, **kwargs): self.top_k = kwargs.pop("top_k", 2) self.norm_topk_prob = kwargs.pop("norm_topk_prob", False) - def topk_navie(self, scores: paddle.Tensor, k: int) -> Tuple[paddle.Tensor, paddle.Tensor]: + def topk_naive(self, scores: paddle.Tensor, k: int) -> Tuple[paddle.Tensor, paddle.Tensor]: """_summary_ Args: @@ -226,7 +233,7 @@ def top1gating( logits += self.gumbel_rsample(logits.shape) gates = self.gate_score_func(logits=logits) - capacity = self._capacity(gates, self.capacity_factor, self.min_capacity) + capacity = self._capacity(gates, self.capacity_factor, self.max_capacity, self.min_capacity) # Create a mask for 1st's expert per token # noisy gating @@ -235,7 +242,7 @@ def top1gating( # mask only used tokens if used_token is not None: - mask1 = 
paddle.einsum("s,se->se", used_token, mask1) # 将used_token与mask1进行逐元素相乘,得到新的mask1 + mask1 = paddle.einsum("s,se->se", used_token, mask1) # 将used_token与mask1进行逐元素相乘,得到新的mask1 # gating decisions exp_counts = paddle.sum(mask1, axis=0) # 计算每个专家的token数量 @@ -264,7 +271,7 @@ def top1gating( _, top_idx = paddle.topk(mask1_rand, k=capacity, axis=0) # 选择top_capacity个token - # 将mask1中的元素与top_idx进行逐元素相乘,得到新的mask1 + # 将mask1中的元素与top_idx进行逐元素相乘,得到新的mask1 new_mask1 = mask1 * paddle.zeros_like(mask1).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=0) mask1 = new_mask1 @@ -290,14 +297,14 @@ def top2gating( ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: """ Args: - logits: [S, E],形状为 [seq_len, num_experts],用于计算top2 gate。 + logits: [S, E],形状为 [seq_len, num_experts],用于计算top2 gate。 cap: 表示每个token可以分发的最大数量的超参数。 Returns: tuple: - capacity: 每个token可分发的最大数量。 - dispatch_masks: 用于dispatching的mask。 - - combine_weights:用于combining的权重。 + - combine_weights: 用于combining的权重。 - scatter_indexes: 用于scattering的索引。 - loss_aux: aux loss。 - loss_z: z loss。 @@ -337,7 +344,7 @@ def top2gating( exp_counts = paddle.sum(mask1 + mask2, axis=0) if self.drop_tokens: # Calculate configured capacity and remove locations outside capacity from mask - capacity = self._capacity(gates, self.capacity_factor, self.min_capacity) + capacity = self._capacity(gates, self.capacity_factor, self.max_capacity, self.min_capacity) # Remove locations outside capacity from mask. mask1 *= (locations1 < capacity).cast(paddle.int64) mask2 *= (locations2 < capacity).cast(paddle.int64) @@ -380,17 +387,19 @@ def topkgating( gates: paddle.Tensor, ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Implements TopKGating on logits.""" + l_zloss = self._cal_z_loss(gates) + # get topk gates top_gate, top_idx = paddle.topk(gates, k=self.top_k, axis=1) # get topk mask mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1) - exp_counts = paddle.sum(mask, axis=0) - l_aux = self._cal_aux_loss(gates, mask) + exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0) + if self.drop_tokens: # Calculate configured capacity and remove locations outside capacity from mask - capacity = self._capacity(gates, self.capacity_factor * self.top_k, self.min_capacity) + capacity = self._capacity(gates, self.capacity_factor * self.top_k, self.max_capacity, self.min_capacity) # update mask and locations by capacity if self.drop_policy == "probs": @@ -426,62 +435,4 @@ def topkgating( combine_weights = paddle.einsum("se,sec->sec", gates_masked, locations_sc) dispatch_mask = combine_weights.cast(paddle.bool) - return capacity, combine_weights, dispatch_mask, exp_counts, l_aux - - def forward(self, hidden_states): - raise NotImplementedError("Please implement the forward function.") - - -class TopKGate(PretrainedMoEGate): - def __init__( - self, - num_experts, - expert_hidden_size, - weight_attr=None, - bias_attr=None, - top_k=2, - capacity_factor=1.0, - eval_capacity_factor=1.0, - scoring_func="softmax", - scaling_attr=None, - ): - super().__init__(num_experts, expert_hidden_size, weight_attr, bias_attr) - self.top_k = top_k - self.scoring_func = scoring_func - self.scaling_attr = scaling_attr - - def forward( - self, - hidden_states: paddle.Tensor, - used_token: paddle.Tensor = None, - ): - bsz, seq_len, hidden_size = hidden_states.shape - hidden_states = hidden_states.reshape([-1, hidden_size]) - logits = F.linear(x=paddle.cast(hidden_states, paddle.float32), 
weight=self.weight, bias=self.bias) - if self.top_k == 1: - gate_output = self.top1gating( - logits, - self.capacity_factor if self.training else self.eval_capacity_factor, - self.min_capacity, - used_token, - self.noisy_gate_policy if self.training else None, - self.drop_tokens, - ) - elif self.top_k == 2: - gate_output = self.top2gating( - logits, - self.capacity_factor if self.training else self.eval_capacity_factor, - self.min_capacity, - self.drop_tokens, - self.top2_2nd_expert_sampling, - ) - else: - gate_output = self.topkgating( - logits, - self.top_k, - self.capacity_factor if self.training else self.eval_capacity_factor, - self.min_capacity, - self.drop_tokens, - ) - - return gate_output + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index 96ed534fe99c..fa1f71b7b72f 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -261,8 +261,6 @@ def forward( capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.gate(reshaped_input) - # print(f"capacity={capacity}") - # self.l_aux, combine_weights, dispatch_mask, self.exp_counts = # self.l_aux : # combine_weights : sec # dispatch_mask : sec @@ -285,9 +283,9 @@ def forward( if self.expert_parallel_degree > 1: expert_output = _AllToAll.apply(expert_output, self.moe_group) - # Re拿到不同device上的expert计算结果 + # combine withe expert weights combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(hidden_state[0].dtype), expert_output) a = combined_output.reshape(hidden_state.shape) - return a + return a, l_aux, l_zloss diff --git a/paddlenlp/transformers/qwen2_moe/modeling.py b/paddlenlp/transformers/qwen2_moe/modeling.py index 5607af83dce0..d91ff3fbf705 100644 --- a/paddlenlp/transformers/qwen2_moe/modeling.py +++ b/paddlenlp/transformers/qwen2_moe/modeling.py @@ -686,8 +686,8 @@ def forward( class Qwen2MoeGate(PretrainedMoEGate): - def __init__(self, num_experts, expert_hidden_size, **kwargs): - super().__init__(num_experts, expert_hidden_size, **kwargs) + def __init__(self, config, num_experts, expert_hidden_size, **kwargs): + super().__init__(config, num_experts, expert_hidden_size, **kwargs) # [hidden_size, n_expert] self.weight = paddle.create_parameter( shape=[expert_hidden_size, num_experts], @@ -704,17 +704,13 @@ def forward(self, hidden_states): _, h_dim = hidden_states.shape # compute gating score - hidden_states = hidden_states.reshape([-1, h_dim]) logits = F.linear(hidden_states, self.weight, None) with paddle.amp.auto_cast(False): - scores = self.gate_score_func(logits=logits.cast(paddle.float32)) + scores = self.gate_score_func(logits=logits) + scores = scores.cast(paddle.get_default_dtype()) - # topk_weight, topk_idx = self.topk_navie(scores, k=2) - # topk_weight, topk_idx = self.topk_group(scores, k=2, n_group=4, topk_group=2) - - capacity, combine_weights, dispatch_mask, exp_counts, l_aux = self.topkgating(scores) - l_zloss = self._cal_z_loss(logits) + capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.topkgating(scores) return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss @@ -722,6 +718,7 @@ def forward(self, hidden_states): class Qwen2MoeSparseMoEBlock(MoELayer): def __init__(self, config: Qwen2MoeConfig): gate = Qwen2MoeGate( + config, config.num_experts, config.hidden_size, top_k=config.num_experts_per_tok, @@ -744,13 +741,77 @@ def __init__(self, config: Qwen2MoeConfig): 
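Editor's note: besides the routed experts, Qwen2-MoE keeps a shared expert whose output is blended into the MoE output through a per-token sigmoid gate, as the block's forward does. A sketch of that blend with hypothetical sizes and Linear layers standing in for Qwen2MoeMLP:

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

batch, seq, hidden = 2, 3, 8                         # hypothetical sizes
hidden_states = paddle.randn([batch, seq, hidden])
routed_output = paddle.randn([batch, seq, hidden])   # stand-in for the MoE layer output

shared_expert = nn.Linear(hidden, hidden)            # stand-in for Qwen2MoeMLP(config, is_shared=True)
shared_expert_gate = nn.Linear(hidden, 1)            # per-token scalar gate

shared_out = shared_expert(hidden_states)
gate = F.sigmoid(shared_expert_gate(hidden_states))  # [batch, seq, 1], broadcasts over hidden
final_hidden_states = routed_output + gate * shared_out
print(final_hidden_states.shape)                     # [2, 3, 8]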
self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias_attr=False) def forward(self, hidden_states): - final_hidden_states = super().forward(hidden_states) + final_hidden_states, l_aux, l_zloss = super().forward(hidden_states) + + # shared_expert_output = self.shared_expert(hidden_states) + # shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output + # final_hidden_states = final_hidden_states + shared_expert_output + + return final_hidden_states, l_aux + + +class Qwen2MoeSparseMoEBlock_OLD(nn.Layer): + def __init__(self, config: Qwen2MoeConfig): + super().__init__() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + + self.gate = nn.Linear(config.hidden_size, self.num_experts, bias_attr=False) + self.experts = nn.LayerList([Qwen2MoeMLP(config) for _ in range(self.num_experts)]) + + self.shared_expert = Qwen2MoeMLP(config, is_shared=True) + self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias_attr=False) + + def forward(self, hidden_states): + batch_size, seq_len, hidden_dim = hidden_states.shape + hidden_states = hidden_states.reshape([-1, hidden_dim]) + # router_logits: [batch_size * seq_len, num_experts] + router_logits = self.gate(hidden_states) + + with paddle.amp.auto_cast(False): + routing_weights = F.softmax(router_logits.astype("float32"), axis=1) + routing_weights, selected_experts = paddle.topk(routing_weights, self.top_k, axis=-1) + if self.norm_topk_prob: # Note: Mixtral is set norm as default, Qwen2Moe is set to no norm + routing_weights /= routing_weights.sum(axis=-1, keepdim=True) + # we cast back to input dtype + routing_weights = routing_weights.astype(hidden_states.dtype) + + final_hidden_states = paddle.zeros( + [batch_size * seq_len, hidden_dim], + dtype=hidden_states.dtype, + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be solicited. + # shape: [num_experts, top_k, batch_size * seq_len] + expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).transpose([2, 1, 0]) + + # Loop over all available experts in the model and perform the computation on each expert. 
+ for expert_id in range(self.num_experts): + expert_layer = self.experts[expert_id] + idx, top_x = paddle.where(expert_mask[expert_id]) + + if top_x.shape[0] == 0: + continue + + current_state = paddle.gather(hidden_states, top_x.squeeze()) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx] + + top_x = top_x.squeeze() + if top_x.shape == []: + top_x = paddle.to_tensor([top_x.item()]) + final_hidden_states = paddle.index_add_( + final_hidden_states, top_x, 0, current_hidden_states.astype(hidden_states.dtype) + ) + + # shared_expert_output = self.shared_expert(hidden_states) + # shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output - shared_expert_output = self.shared_expert(hidden_states) - shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output - final_hidden_states = final_hidden_states + shared_expert_output + # final_hidden_states = final_hidden_states + shared_expert_output - return final_hidden_states, 0.0 + final_hidden_states = final_hidden_states.reshape([batch_size, seq_len, hidden_dim]) + return final_hidden_states, router_logits class Qwen2MoeDecoderLayer(nn.Layer): From 83bdeddaff6a7d9e05c6b4986f074c1430ee1aa3 Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Thu, 7 Nov 2024 10:22:49 +0000 Subject: [PATCH 15/20] update token_priority method --- paddlenlp/transformers/moe_gate.py | 70 ++++++++++++++++++++++++------ 1 file changed, 57 insertions(+), 13 deletions(-) diff --git a/paddlenlp/transformers/moe_gate.py b/paddlenlp/transformers/moe_gate.py index 64a2ecf55e9a..b88a220901d9 100644 --- a/paddlenlp/transformers/moe_gate.py +++ b/paddlenlp/transformers/moe_gate.py @@ -176,6 +176,56 @@ def __init__(self, config, num_experts, expert_hidden_size, **kwargs): self.top_k = kwargs.pop("top_k", 2) self.norm_topk_prob = kwargs.pop("norm_topk_prob", False) + def _priority(self, topk_idx: paddle.Tensor, capacity: int) -> paddle.Tensor: + """_summary_ + The priority is the cumulative sum of the expert indices. + + This method is used in hunyuan model + Args: + topk_idx (paddle.Tensor): [batch_size * seq_len, topk] + + Returns: + paddle.Tensor: cumsum locations + """ + # Make num_selected_experts the leading axis to ensure that top-1 choices + # have priority over top-2 choices, which have priority over top-3 choices, + # etc. + expert_index = paddle.transpose(topk_idx, [1, 0]) # [topk, B*S] + # Shape: [num_selected_experts * tokens_per_group] + expert_index = expert_index.reshape([-1]) + + # Create mask out of indices. + # Shape: [tokens_per_group * num_selected_experts, num_experts]. + expert_mask = F.one_hot(expert_index, self.num_experts).cast(paddle.int32) + + # Experts have a fixed capacity that we cannot exceed. A token's priority + # within the expert's buffer is given by the masked, cumulative capacity of + # its target expert. + # Shape: [tokens_per_group * num_selected_experts, num_experts]. + token_priority = paddle.cumsum(expert_mask, axis=0) * expert_mask - 1 + # Shape: [num_selected_experts, tokens_per_group, num_experts]. + token_priority = token_priority.reshape((self.top_k, -1, self.num_experts)) + # Shape: [tokens_per_group, num_selected_experts, num_experts]. + token_priority = paddle.transpose(token_priority, [1, 0, 2]) + # For each token, across all selected experts, select the only non-negative + # (unmasked) priority. Now, for group G routing to expert E, token T has + # non-negative priority (i.e. 
token_priority[G,T,E] >= 0) if and only if E + # is its targeted expert. + # Shape: [tokens_per_group, num_experts]. + token_priority = paddle.max(token_priority, axis=1) + + # Token T can only be routed to expert E if its priority is positive and + # less than the expert capacity. One-hot matrix will ignore indices outside + # the range [0, expert_capacity). + # Shape: [tokens_per_group, num_experts, expert_capacity]. + valid_mask = paddle.logical_and(token_priority >= 0, token_priority < capacity) + token_priority = paddle.masked_fill(token_priority, ~valid_mask, 0) + dispatch_mask = F.one_hot(token_priority, capacity) + valid_mask = valid_mask.unsqueeze(-1).expand(valid_mask.shape + [capacity]) + dispatch_mask = paddle.masked_fill(dispatch_mask, ~valid_mask, 0) + + return dispatch_mask + def topk_naive(self, scores: paddle.Tensor, k: int) -> Tuple[paddle.Tensor, paddle.Tensor]: """_summary_ @@ -405,23 +455,19 @@ def topkgating( if self.drop_policy == "probs": topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1) capacity_probs, capacity_indices = paddle.topk(topk_masked_gates, k=capacity, axis=0, sorted=False) - capacity_mask = paddle.zeros_like(gates).put_along_axis(capacity_indices, paddle.to_tensor(1.0), axis=0) # fmt:skip - mask = mask * capacity_mask - locations = paddle.cumsum(mask, axis=0) - 1 + token_priority = self._priority(capacity_indices, capacity) elif self.drop_policy == "position": - locations = paddle.cumsum(mask, axis=0) - 1 - mask *= (locations < capacity).cast(paddle.int64) + token_priority = self._priority(top_idx, capacity) else: raise ValueError(f"Invalid drop_policy: {self.drop_policy}") else: # Do not drop tokens - set capacity according to current expert assignments - locations = paddle.cumsum(mask, axis=0) - 1 - - new_capacity = paddle.max(exp_counts) + local_capacity = paddle.max(exp_counts) if self.group is not None: - dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=self.group) - capacity = int(new_capacity) + dist.all_reduce(local_capacity, op=dist.ReduceOp.MAX, group=self.group) + capacity = int(local_capacity) + token_priority = self._priority(top_idx, capacity) # normalize gates gates_masked = gates * mask @@ -430,9 +476,7 @@ def topkgating( if self.norm_topk_prob: gates_masked = gates_masked / denom_s - # dispatch_mask - locations_sc = self._one_hot_to_float(locations * mask, num_classes=capacity) - combine_weights = paddle.einsum("se,sec->sec", gates_masked, locations_sc) + combine_weights = paddle.einsum("se,sec->sec", gates_masked, token_priority) dispatch_mask = combine_weights.cast(paddle.bool) return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss From 63f67552536cf2abc697db5167890b91da985361 Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Thu, 7 Nov 2024 13:20:04 +0000 Subject: [PATCH 16/20] update data type --- paddlenlp/transformers/moe_gate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddlenlp/transformers/moe_gate.py b/paddlenlp/transformers/moe_gate.py index b88a220901d9..adf830eeb7f7 100644 --- a/paddlenlp/transformers/moe_gate.py +++ b/paddlenlp/transformers/moe_gate.py @@ -220,7 +220,7 @@ def _priority(self, topk_idx: paddle.Tensor, capacity: int) -> paddle.Tensor: # Shape: [tokens_per_group, num_experts, expert_capacity]. 
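# ---------------------------------------------------------------------------
# Illustration only (not the PaddleNLP implementation): a minimal NumPy
# walk-through of the token-priority idea behind _priority(). A token's slot in
# its expert's buffer is the masked cumulative count of earlier tokens routed
# to the same expert; slots at or beyond the capacity are dropped. All names
# below (toy_priority, topk_idx, ...) are illustrative.
import numpy as np

def toy_priority(topk_idx: np.ndarray, num_experts: int, capacity: int) -> np.ndarray:
    """topk_idx: [num_tokens, top_k] expert ids -> dispatch mask [num_tokens, num_experts, capacity]."""
    top_k = topk_idx.shape[1]
    # Leading top-k axis so top-1 assignments claim buffer slots before top-2, etc.
    expert_index = topk_idx.T.reshape(-1)                        # [top_k * num_tokens]
    expert_mask = np.eye(num_experts, dtype=np.int64)[expert_index]
    # Slot of each (choice, token) pair inside its expert's buffer; -1 where unrouted.
    priority = np.cumsum(expert_mask, axis=0) * expert_mask - 1
    priority = priority.reshape(top_k, -1, num_experts).transpose(1, 0, 2).max(axis=1)
    valid = (priority >= 0) & (priority < capacity)
    dispatch = np.zeros(priority.shape + (capacity,), dtype=np.int64)
    rows, cols = np.nonzero(valid)
    dispatch[rows, cols, priority[rows, cols]] = 1
    return dispatch

# Example: 4 tokens, 2 experts, top-1 routing, capacity 2 -> the third token
# sent to expert 0 overflows its buffer and receives no dispatch slot.
# toy_priority(np.array([[0], [0], [0], [1]]), num_experts=2, capacity=2).sum(axis=(1, 2))
# -> array([1, 1, 0, 1])
# ---------------------------------------------------------------------------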
valid_mask = paddle.logical_and(token_priority >= 0, token_priority < capacity) token_priority = paddle.masked_fill(token_priority, ~valid_mask, 0) - dispatch_mask = F.one_hot(token_priority, capacity) + dispatch_mask = F.one_hot(token_priority, capacity).cast(paddle.int32) valid_mask = valid_mask.unsqueeze(-1).expand(valid_mask.shape + [capacity]) dispatch_mask = paddle.masked_fill(dispatch_mask, ~valid_mask, 0) @@ -476,7 +476,7 @@ def topkgating( if self.norm_topk_prob: gates_masked = gates_masked / denom_s - combine_weights = paddle.einsum("se,sec->sec", gates_masked, token_priority) + combine_weights = paddle.einsum("se,sec->sec", gates_masked, token_priority.cast(paddle.get_default_dtype())) dispatch_mask = combine_weights.cast(paddle.bool) return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss From 17537b3ecd16fcbbb77c1cacc1f1ffb5df995e80 Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Fri, 8 Nov 2024 14:13:01 +0800 Subject: [PATCH 17/20] remove old moe --- paddlenlp/transformers/moe_layer.py | 4 +- paddlenlp/transformers/qwen2_moe/modeling.py | 70 +------------------- 2 files changed, 4 insertions(+), 70 deletions(-) diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index fa1f71b7b72f..ec7e9f498c62 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy from typing import Any, Tuple import paddle @@ -194,10 +193,9 @@ def __init__( self.enable_recompute = False self.experts = nn.LayerList([]) - expert = expert_class(expert_kwargs) for i in range(self.moe_num_experts): if i // self.moe_num_experts_per_device == self.moe_rank: - self.experts.append(deepcopy(expert)) + self.experts.append(expert_class(expert_kwargs)) else: self.experts.append(None) diff --git a/paddlenlp/transformers/qwen2_moe/modeling.py b/paddlenlp/transformers/qwen2_moe/modeling.py index d91ff3fbf705..18507c1d5dc7 100644 --- a/paddlenlp/transformers/qwen2_moe/modeling.py +++ b/paddlenlp/transformers/qwen2_moe/modeling.py @@ -743,77 +743,13 @@ def __init__(self, config: Qwen2MoeConfig): def forward(self, hidden_states): final_hidden_states, l_aux, l_zloss = super().forward(hidden_states) - # shared_expert_output = self.shared_expert(hidden_states) - # shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output - # final_hidden_states = final_hidden_states + shared_expert_output + shared_expert_output = self.shared_expert(hidden_states) + shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output + final_hidden_states = final_hidden_states + shared_expert_output return final_hidden_states, l_aux -class Qwen2MoeSparseMoEBlock_OLD(nn.Layer): - def __init__(self, config: Qwen2MoeConfig): - super().__init__() - self.num_experts = config.num_experts - self.top_k = config.num_experts_per_tok - self.norm_topk_prob = config.norm_topk_prob - - self.gate = nn.Linear(config.hidden_size, self.num_experts, bias_attr=False) - self.experts = nn.LayerList([Qwen2MoeMLP(config) for _ in range(self.num_experts)]) - - self.shared_expert = Qwen2MoeMLP(config, is_shared=True) - self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias_attr=False) - - def forward(self, hidden_states): - batch_size, seq_len, hidden_dim = hidden_states.shape - hidden_states = hidden_states.reshape([-1, hidden_dim]) - # router_logits: [batch_size 
* seq_len, num_experts] - router_logits = self.gate(hidden_states) - - with paddle.amp.auto_cast(False): - routing_weights = F.softmax(router_logits.astype("float32"), axis=1) - routing_weights, selected_experts = paddle.topk(routing_weights, self.top_k, axis=-1) - if self.norm_topk_prob: # Note: Mixtral is set norm as default, Qwen2Moe is set to no norm - routing_weights /= routing_weights.sum(axis=-1, keepdim=True) - # we cast back to input dtype - routing_weights = routing_weights.astype(hidden_states.dtype) - - final_hidden_states = paddle.zeros( - [batch_size * seq_len, hidden_dim], - dtype=hidden_states.dtype, - ) - - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be solicited. - # shape: [num_experts, top_k, batch_size * seq_len] - expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).transpose([2, 1, 0]) - - # Loop over all available experts in the model and perform the computation on each expert. - for expert_id in range(self.num_experts): - expert_layer = self.experts[expert_id] - idx, top_x = paddle.where(expert_mask[expert_id]) - - if top_x.shape[0] == 0: - continue - - current_state = paddle.gather(hidden_states, top_x.squeeze()) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx] - - top_x = top_x.squeeze() - if top_x.shape == []: - top_x = paddle.to_tensor([top_x.item()]) - final_hidden_states = paddle.index_add_( - final_hidden_states, top_x, 0, current_hidden_states.astype(hidden_states.dtype) - ) - - # shared_expert_output = self.shared_expert(hidden_states) - # shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output - - # final_hidden_states = final_hidden_states + shared_expert_output - - final_hidden_states = final_hidden_states.reshape([batch_size, seq_len, hidden_dim]) - return final_hidden_states, router_logits - - class Qwen2MoeDecoderLayer(nn.Layer): def __init__(self, config: Qwen2MoeConfig, layerwise_recompute: bool = False): super().__init__() From 2b0bf16dbdbbca096c30200083519190004b07d9 Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Mon, 11 Nov 2024 15:53:00 +0800 Subject: [PATCH 18/20] fix moe capacity reduce.Max --- paddlenlp/transformers/moe_layer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index ec7e9f498c62..f577b2b55c42 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -22,6 +22,8 @@ from paddle.distributed.communication import stream from paddle.distributed.communication.group import Group +from .moe_gate import PretrainedMoEGate + def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): """ @@ -162,7 +164,7 @@ def __init__( moe_num_experts: int, expert_class: nn.Layer, expert_kwargs: dict, - gate: nn.Layer, + gate: PretrainedMoEGate, capacity: int = 1.0, moe_group: str = "data", all_to_all_dropout=0.0, @@ -174,7 +176,7 @@ def __init__( self.moe_num_experts = moe_num_experts self.capacity = capacity - if dist.get_world_size() > 1: + if dist.get_world_size() > 1 and moe_group == "data": self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() self.moe_rank = dist.get_rank(self.moe_group) self.moe_rank = 0 if self.moe_rank < 0 else self.moe_rank @@ -184,6 +186,7 @@ def __init__( self.moe_num_experts, self.expert_parallel_degree ) else: + # when moe_group is dummy, we don't need to use all_to_all 
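# ---------------------------------------------------------------------------
# Illustration only (not part of the patch): which global expert ids a rank
# keeps under the expert-parallel layout used in MoELayer.__init__ — rank r
# instantiates the experts whose id // experts_per_device equals r and holds
# None placeholders for the rest. Pure Python; the helper name is illustrative.
def toy_local_expert_ids(moe_num_experts: int, expert_parallel_degree: int, moe_rank: int):
    assert moe_num_experts % expert_parallel_degree == 0, "experts must divide evenly across ranks"
    per_device = moe_num_experts // expert_parallel_degree
    return [i for i in range(moe_num_experts) if i // per_device == moe_rank]

# e.g. 8 experts across 4 expert-parallel ranks -> rank 2 owns experts [4, 5]
# toy_local_expert_ids(8, 4, 2) -> [4, 5]
# ---------------------------------------------------------------------------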
self.moe_group = None self.moe_rank = 0 self.expert_parallel_degree = 1 @@ -200,6 +203,7 @@ def __init__( self.experts.append(None) self.gate = gate + self.gate.group = self.moe_group def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree): assert ( From 801f0ff1d8d9bbed5007f9082f0ab2a45f80dedf Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Wed, 13 Nov 2024 11:39:17 +0000 Subject: [PATCH 19/20] update comment --- .../transformers/deepseek_v2/modeling.py | 2 + paddlenlp/transformers/moe_gate.py | 54 ++++++++----------- paddlenlp/transformers/moe_layer.py | 25 +-------- 3 files changed, 24 insertions(+), 57 deletions(-) diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py index 933293fc6402..5c5c0b43b95a 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/transformers/deepseek_v2/modeling.py @@ -18,6 +18,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Paddle DeepSeek model.""" +from __future__ import annotations + import math import warnings from functools import partial diff --git a/paddlenlp/transformers/moe_gate.py b/paddlenlp/transformers/moe_gate.py index adf830eeb7f7..e4e1fefdfb35 100644 --- a/paddlenlp/transformers/moe_gate.py +++ b/paddlenlp/transformers/moe_gate.py @@ -97,16 +97,16 @@ def _capacity( def _cal_aux_loss(self, gates, mask): """ - 计算辅助损失 + Calculate auxiliary loss Args: - gates (paddle.Tensor): 表示每个expert的输出概率。形状为[batch_size, num_experts] - mask (paddle.Tensor): 表示每个样本是否属于某个expert。形状为[batch_size, num_experts] + gates (paddle.Tensor): Represents the output probability of each expert. The shape is [batch_size, num_experts] + mask (paddle.Tensor): Represents whether each sample belongs to a certain expert. The shape is [batch_size, num_experts] Returns: - paddle.Tensor: 辅助损失值。 + paddle.Tensor: The value of auxiliary loss. - """ + """ me = paddle.mean(gates, axis=0) ce = paddle.mean(mask.cast("float32"), axis=0) if self.global_aux_loss: @@ -123,11 +123,13 @@ def _cal_aux_loss(self, gates, mask): def _cal_z_loss(self, logits) -> paddle.Tensor: """ - 计算z损失 + Calculate the z loss. + Args: - logits (paddle.paddle.Tensor): 模型输出。形状为[batch_size, num_experts] + logits (paddle.Tensor): Model output. The shape is [batch_size, num_experts]. + Returns: - paddle.paddle.Tensor: z损失值。 + paddle.Tensor: The z loss value. 
""" l_zloss = logits.exp().sum(1).log().square().mean() return l_zloss @@ -287,22 +289,24 @@ def top1gating( # Create a mask for 1st's expert per token # noisy gating - indices1_s = paddle.argmax(logits if self.noisy_gate_policy == "RSample" else gates, axis=1) # 仅保存最大值位置 - mask1 = self._one_hot_to_float(indices1_s, num_classes=self.num_experts) # 将最大值位置转换为one-hot向量 [s, e] + # Only save the position of the maximum value + indices1_s = paddle.argmax(logits if self.noisy_gate_policy == "RSample" else gates, axis=1) + # Convert the position of the maximum value to a one-hot vector [s, e] + mask1 = self._one_hot_to_float(indices1_s, num_classes=self.num_experts) # mask only used tokens if used_token is not None: - mask1 = paddle.einsum("s,se->se", used_token, mask1) # 将used_token与mask1进行逐元素相乘,得到新的mask1 + mask1 = paddle.einsum("s,se->se", used_token, mask1) # Element-wise multiply used_token with mask1 to obtain a new mask1 # gating decisions - exp_counts = paddle.sum(mask1, axis=0) # 计算每个专家的token数量 + exp_counts = paddle.sum(mask1, axis=0) # Calculate the number of tokens for each expert # if we don't want to drop any tokens if not self.drop_tokens: - new_capacity = paddle.max(exp_counts) # 计算每个专家的token数量 + new_capacity = paddle.max(exp_counts) # Calculate the number of tokens for each expert # Communicate across expert processes to pick the maximum capacity. if self.group is not None: - dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=self.group) # 在专家进程之间进行最大值计算 + dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=self.group) # Calculate the maximum value among expert processes # Make sure the capacity value does not exceed the number of tokens. capacity = int(min(new_capacity, paddle.tensor(mask1.size(0)))) @@ -319,17 +323,16 @@ def top1gating( logits.shape[0] >= self.min_capacity ), "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size." 
- _, top_idx = paddle.topk(mask1_rand, k=capacity, axis=0) # 选择top_capacity个token + _, top_idx = paddle.topk(mask1_rand, k=capacity, axis=0) # Select top_capacity tokens - # 将mask1中的元素与top_idx进行逐元素相乘,得到新的mask1 new_mask1 = mask1 * paddle.zeros_like(mask1).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=0) mask1 = new_mask1 # Compute locations in capacity buffer - locations1 = paddle.cumsum(mask1, axis=0) - 1 # 计算每个token在mask1中的位置 + locations1 = paddle.cumsum(mask1, axis=0) - 1 # Compute the position of each token in mask1 # Store the capacity location for each token - locations1_s = paddle.sum(locations1 * mask1, axis=1).cast(paddle.int64) # 计算每个token在mask1中的位置 + locations1_s = paddle.sum(locations1 * mask1, axis=1).cast(paddle.int64) # Normalize gate probabilities mask1_float = mask1.cast(paddle.float32) @@ -345,21 +348,6 @@ def top2gating( self, logits: paddle.Tensor, ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """ - Args: - logits: [S, E],形状为 [seq_len, num_experts],用于计算top2 gate。 - cap: 表示每个token可以分发的最大数量的超参数。 - - Returns: - tuple: - - capacity: 每个token可分发的最大数量。 - - dispatch_masks: 用于dispatching的mask。 - - combine_weights: 用于combining的权重。 - - scatter_indexes: 用于scattering的索引。 - - loss_aux: aux loss。 - - loss_z: z loss。 - """ - """Implements Top2Gating on logits.""" # everything is in fp32 in this function gates = self.gate_score_func(logits=logits) diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index f577b2b55c42..56369c6c3b92 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from typing import Any, Tuple @@ -134,30 +135,6 @@ def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor]: class MoELayer(nn.Layer): - """MOELayer module which implements MixtureOfExperts as described in Gshard_. - :: - - gate = Top2Gate(model_dim, num_experts) - - moe = MoELayer(gate, expert) - output = moe(input) - l_aux = moe.l_aux - - .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf - - Args: - gate (paddle.nn.Layer): - gate network - expert (paddle.nn.LayerList): - expert network, LayerList 长度是 per_device 上的 expert 数。 - group (paddle.ProgressGroup) - recompute: 启用MOE内recomupte - Returns: - output - combine_weight - router-loss - """ - def __init__( self, config, From 358483bf77a0acac06628f20f9834c7dce582cbe Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Fri, 15 Nov 2024 06:24:09 +0000 Subject: [PATCH 20/20] lint --- paddlenlp/transformers/moe_gate.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/paddlenlp/transformers/moe_gate.py b/paddlenlp/transformers/moe_gate.py index e4e1fefdfb35..8118ba60f7ac 100644 --- a/paddlenlp/transformers/moe_gate.py +++ b/paddlenlp/transformers/moe_gate.py @@ -106,7 +106,7 @@ def _cal_aux_loss(self, gates, mask): Returns: paddle.Tensor: The value of auxiliary loss. - """ + """ me = paddle.mean(gates, axis=0) ce = paddle.mean(mask.cast("float32"), axis=0) if self.global_aux_loss: @@ -124,10 +124,10 @@ def _cal_aux_loss(self, gates, mask): def _cal_z_loss(self, logits) -> paddle.Tensor: """ Calculate the z loss. - + Args: logits (paddle.Tensor): Model output. The shape is [batch_size, num_experts]. - + Returns: paddle.Tensor: The z loss value. 
""" @@ -292,11 +292,13 @@ def top1gating( # Only save the position of the maximum value indices1_s = paddle.argmax(logits if self.noisy_gate_policy == "RSample" else gates, axis=1) # Convert the position of the maximum value to a one-hot vector [s, e] - mask1 = self._one_hot_to_float(indices1_s, num_classes=self.num_experts) + mask1 = self._one_hot_to_float(indices1_s, num_classes=self.num_experts) # mask only used tokens if used_token is not None: - mask1 = paddle.einsum("s,se->se", used_token, mask1) # Element-wise multiply used_token with mask1 to obtain a new mask1 + mask1 = paddle.einsum( + "s,se->se", used_token, mask1 + ) # Element-wise multiply used_token with mask1 to obtain a new mask1 # gating decisions exp_counts = paddle.sum(mask1, axis=0) # Calculate the number of tokens for each expert @@ -306,7 +308,9 @@ def top1gating( new_capacity = paddle.max(exp_counts) # Calculate the number of tokens for each expert # Communicate across expert processes to pick the maximum capacity. if self.group is not None: - dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=self.group) # Calculate the maximum value among expert processes + dist.all_reduce( + new_capacity, op=dist.ReduceOp.MAX, group=self.group + ) # Calculate the maximum value among expert processes # Make sure the capacity value does not exceed the number of tokens. capacity = int(min(new_capacity, paddle.tensor(mask1.size(0))))