From 4a981faac94e305812dc3bd4bf82bb710dd530d1 Mon Sep 17 00:00:00 2001 From: wufeisheng Date: Tue, 10 Oct 2023 19:36:57 +0800 Subject: [PATCH 1/3] reconstruct fused_transformer_layers --- .../transformers/bloom/modeling.py | 8 +- .../transformers/chatglm/modeling.py | 7 +- .../transformers/fused_transformer_layers.py | 756 +++++++++++++++++- .../experimental/transformers/gpt/modeling.py | 6 +- .../transformers/llama/modeling.py | 12 +- .../experimental/transformers/opt/modeling.py | 7 +- tests/llm/test_predictor.py | 24 +- 7 files changed, 801 insertions(+), 19 deletions(-) diff --git a/paddlenlp/experimental/transformers/bloom/modeling.py b/paddlenlp/experimental/transformers/bloom/modeling.py index 8825af26d5ad..0c37690dc2b1 100644 --- a/paddlenlp/experimental/transformers/bloom/modeling.py +++ b/paddlenlp/experimental/transformers/bloom/modeling.py @@ -21,7 +21,8 @@ from paddlenlp_ops import get_padding_offset from paddlenlp.experimental.transformers.fused_transformer_layers import ( - FusedMultiTransformer, + FusedMultiTransformerBase, + FusedMultiTransformerConfig, ) from paddlenlp.experimental.transformers.generation_utils import ( GenerationInferenceModel, @@ -112,7 +113,8 @@ def __init__(self, config): ffn1_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn1_bias".format(i)) for i in range(config.n_layer)] ffn2_weight_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn2_weight".format(i)) for i in range(config.n_layer)] ffn2_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn2_bias".format(i)) for i in range(config.n_layer)] - self.transformer_block = FusedMultiTransformer( + + transformer_config = FusedMultiTransformerConfig( self.embed_dim, self.n_head, 4 * self.embed_dim, @@ -133,6 +135,8 @@ def __init__(self, config): ffn2_weight_attrs=ffn2_weight_attrs, ffn2_bias_attrs=ffn2_bias_attrs, ) + + self.transformer_block = FusedMultiTransformerBase(transformer_config) self.cache_kvs = [] # Final Layer Norm diff --git a/paddlenlp/experimental/transformers/chatglm/modeling.py b/paddlenlp/experimental/transformers/chatglm/modeling.py index c87fa8f8c9f2..574c68ab7207 100644 --- a/paddlenlp/experimental/transformers/chatglm/modeling.py +++ b/paddlenlp/experimental/transformers/chatglm/modeling.py @@ -20,7 +20,8 @@ from paddlenlp_ops import get_padding_offset from paddlenlp.experimental.transformers.fused_transformer_layers import ( - FusedMultiTransformer, + FusedMultiTransformerBase, + FusedMultiTransformerConfig, ) from paddlenlp.experimental.transformers.generation_utils import ( GenerationInferenceModel, @@ -183,7 +184,8 @@ def __init__(self, config: ChatGLMConfig): ] ffn2_bias_attrs = [paddle.ParamAttr(name="fusemt.{}.ffn2_bias".format(i)) for i in range(config.num_layers)] alpha = (2 * self.config.num_hidden_layers) ** 0.5 - self.transformer_block = FusedMultiTransformer( + + transformer_config = FusedMultiTransformerConfig( config.hidden_size, config.num_attention_heads, 4 * config.hidden_size, @@ -209,6 +211,7 @@ def __init__(self, config: ChatGLMConfig): norm_type="layernorm", use_neox_rotary_style=True, ) + self.transformer_block = FusedMultiTransformerBase(transformer_config) def remove_padding(self, input_ids, seq_lens_this_time): cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time) diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py index 796bf02ec289..0dfd9143c159 100644 --- a/paddlenlp/experimental/transformers/fused_transformer_layers.py +++ 
b/paddlenlp/experimental/transformers/fused_transformer_layers.py @@ -44,7 +44,13 @@ ) -__all__ = ["FusedMultiTransformer"] +__all__ = [ + "FusedMultiTransformerConfig", + "FusedMultiTransformerBase", + "FusedMultiTransformerPostLayernorm", + "FusedMultiTransformerWeightOnly", + "FusedMultiTransformerWeightOnlyPostLayernorm", +] # for distributed tensor model parallel @@ -132,6 +138,742 @@ def fused_act_bias_wrapper( return out +class FusedMultiTransformerConfig: + def __init__( + self, + embed_dim, + num_heads, + dim_feedforward, + quant_bits=-1, # -1 means use Half precision. + dropout_rate=0.0, + activation="gelu", + norm_type="layernorm", + use_neox_rotary_style=False, + normalize_before=True, + ln_scale_attrs=None, + ln_bias_attrs=None, + qkv_weight_attrs=None, + qkv_weight_scale_attrs=None, + qkv_bias_attrs=None, + linear_weight_attrs=None, + linear_weight_scale_attrs=None, + linear_bias_attrs=None, + ffn_ln_scale_attrs=None, + ffn_ln_bias_attrs=None, + ffn1_weight_attrs=None, + ffn1_weight_scale_attrs=None, + ffn1_bias_attrs=None, + ffn2_weight_attrs=None, + ffn2_weight_scale_attrs=None, + ffn2_bias_attrs=None, + epsilon=1e-5, + residual_alpha=1.0, + num_layers=-1, + nranks=1, + trans_qkvw=True, + ring_id=-1, + ): + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dim_feedforward = dim_feedforward + self.quant_bits = quant_bits + self.dropout_rate = dropout_rate + self.activation = activation + self.norm_type = norm_type + + self.use_neox_rotary_style = use_neox_rotary_style + self.normalize_before = normalize_before + self.ln_scale_attrs = ln_scale_attrs + self.ln_bias_attrs = ln_bias_attrs + self.qkv_weight_attrs = qkv_weight_attrs + self.qkv_weight_scale_attrs = qkv_weight_scale_attrs + self.qkv_bias_attrs = qkv_bias_attrs + self.linear_weight_attrs = linear_weight_attrs + self.linear_weight_scale_attrs = linear_weight_scale_attrs + self.linear_bias_attrs = linear_bias_attrs + self.ffn_ln_scale_attrs = ffn_ln_scale_attrs + self.ffn_ln_bias_attrs = ffn_ln_bias_attrs + self.ffn1_weight_attrs = ffn1_weight_attrs + self.ffn1_weight_scale_attrs = ffn1_weight_scale_attrs + self.ffn1_bias_attrs = ffn1_bias_attrs + self.ffn2_weight_attrs = ffn2_weight_attrs + self.ffn2_weight_scale_attrs = ffn2_weight_scale_attrs + self.ffn2_bias_attrs = ffn2_bias_attrs + self.epsilon = epsilon + self.residual_alpha = residual_alpha + self.num_layers = num_layers + self.nranks = nranks + self.trans_qkvw = trans_qkvw + self.ring_id = ring_id + + +class FusedMultiTransformerBase(Layer): + def __init__(self, config: FusedMultiTransformerConfig): + super().__init__() + + assert config.embed_dim > 0, "Expected embed_dim to be greater than 0, " "but received {}".format( + config.embed_dim + ) + assert config.num_heads > 0, "Expected nhead to be greater than 0, " "but received {}".format(config.num_heads) + assert config.dim_feedforward > 0, "Expected dim_feedforward to be greater than 0, but received {}".format( + config.dim_feedforward + ) + + # self.normalize_before = normalize_before + self._dtype = self._helper.get_default_dtype() + self._epsilon = config.epsilon + self._residual_alpha = config.residual_alpha + self._trans_qkvw = config.trans_qkvw + self._ring_id = config.ring_id + self.nranks = config.nranks + self.norm_type = config.norm_type + if self.norm_type == "layernorm": + self.norm_func = fused_layer_norm + elif self.norm_type == "rmsnorm": + self.norm_func = fused_rms_norm + else: + raise NotImplementedError("Only support norm type of [layernorm, rmsnorm]") + 
self.use_neox_rotary_style = config.use_neox_rotary_style + self._norm_weight_dtype = "float32" if self.norm_type == "layernorm" else self._dtype + + self.activation = config.activation + + self.embed_dim = config.embed_dim + self.num_heads = config.num_heads + self.head_dim = config.embed_dim // config.num_heads + assert self.head_dim * config.num_heads == config.embed_dim, "embed_dim must be divisible by num_heads" + + # tensor model parallel + if config.nranks > 1: + assert config.ring_id != -1 + assert config.num_heads % config.nranks == 0 + assert config.dim_feedforward % config.nranks == 0 + num_heads = config.num_heads // config.nranks + dim_feedforward = config.dim_feedforward // config.nranks + self._dim_feedforward = dim_feedforward + + if isinstance(config.qkv_weight_attrs, (list, tuple)): + self.num_layers = len(config.qkv_weight_attrs) + assert self.num_layers > 0 + + self.weight_dtype = self._dtype + self.create_params_type = self.get_weight_create_dype() + + self.ln_scales, self.ln_biases = [], [] + self.qkv_weights, self.qkv_biases = [], [] + self.linear_weights, self.linear_biases = [], [] + self.ffn_ln_scales, self.ffn_ln_biases = [], [] + self.ffn1_weights, self.ffn1_biases = [], [] + self.ffn2_weights, self.ffn2_biases = [], [] + + for i in range(self.num_layers): + ln_scale_attr = self.get_attr(config.ln_scale_attrs, i) + ln_bias_attr = self.get_attr(config.ln_bias_attrs, i) + qkv_weight_attr = self.get_attr(config.qkv_weight_attrs, i) + + qkv_bias_attr = self.get_attr(config.qkv_bias_attrs, i) + linear_weight_attr = self.get_attr(config.linear_weight_attrs, i) + linear_bias_attr = self.get_attr(config.linear_bias_attrs, i) + + ffn_ln_scale_attr = self.get_attr(config.ffn_ln_scale_attrs, i) + ffn_ln_bias_attr = self.get_attr(config.ffn_ln_bias_attrs, i) + ffn1_weight_attr = self.get_attr(config.ffn1_weight_attrs, i) + ffn1_bias_attr = self.get_attr(config.ffn1_bias_attrs, i) + ffn2_weight_attr = self.get_attr(config.ffn2_weight_attrs, i) + ffn2_bias_attr = self.get_attr(config.ffn2_bias_attrs, i) + + ln_scale = self.create_parameter( + attr=ln_scale_attr, + shape=[config.embed_dim], + default_initializer=Constant(value=1.0), + dtype=self._norm_weight_dtype, + ) + ln_bias = None + if ln_bias_attr: + ln_bias = self.create_parameter( + attr=ln_bias_attr, + shape=[config.embed_dim], + is_bias=True, + dtype=self._norm_weight_dtype, + ) + + self.get_weight_shape(num_heads, dim_feedforward, config) + + qkv_weight = self.create_parameter( + shape=self.qkv_weight_shape, + attr=qkv_weight_attr, + dtype=self.create_params_type, + is_bias=False, + ) + + qkv_bias = None + if qkv_bias_attr: + qkv_bias = self.create_parameter( + shape=[3 * num_heads * self.head_dim], + attr=qkv_bias_attr, + dtype=self._dtype, + is_bias=True, + ) + + linear_weight = self.create_parameter( + shape=self.linear_weight_shape, + attr=linear_weight_attr, + dtype=self.create_params_type, + is_bias=False, + ) + + linear_bias = None + if linear_bias_attr: + linear_bias = self.create_parameter( + shape=[config.embed_dim], + attr=linear_bias_attr, + dtype=self._dtype, + is_bias=True, + ) + + ffn_ln_scale = self.create_parameter( + shape=[config.embed_dim], + attr=ffn_ln_scale_attr, + is_bias=False, + default_initializer=Constant(1.0), + dtype=self._norm_weight_dtype, + ) + + ffn_ln_bias = None + if ffn_ln_bias_attr: + ffn_ln_bias = self.create_parameter( + shape=[config.embed_dim], + attr=ffn_ln_bias_attr, + is_bias=True, + dtype=self._norm_weight_dtype, + ) + + ffn1_weight = self.create_parameter( + 
shape=self.ffn1_weight_shape, + attr=ffn1_weight_attr, + dtype=self.create_params_type, + is_bias=False, + ) + + ffn1_bias = None + if ffn1_bias_attr: + ffn1_bias = self.create_parameter( + shape=[dim_feedforward * 2] if config.activation.endswith("glu") else [dim_feedforward], + attr=ffn1_bias_attr, + dtype=self._dtype, + is_bias=True, + ) + + ffn2_weight = self.create_parameter( + shape=self.ffn2_weight_shape, + attr=ffn2_weight_attr, + dtype=self.create_params_type, + is_bias=False, + ) + + ffn2_bias = None + if ffn2_bias_attr: + ffn2_bias = self.create_parameter( + shape=[config.embed_dim], + attr=ffn2_bias_attr, + dtype=self._dtype, + is_bias=True, + ) + + # tensor model parallel + if config.nranks > 1: + # column parallel + _set_var_distributed(qkv_weight) + _set_var_distributed(qkv_bias) + _set_var_distributed(ffn1_weight) + _set_var_distributed(ffn1_bias) + # row parallel + _set_var_distributed(linear_weight) + _set_var_distributed(ffn2_weight) + + self.ln_scales.append(ln_scale) + self.ln_biases.append(ln_bias) + self.qkv_weights.append(qkv_weight) + self.qkv_biases.append(qkv_bias) + self.linear_weights.append(linear_weight) + self.linear_biases.append(linear_bias) + + self.ffn_ln_scales.append(ffn_ln_scale) + self.ffn_ln_biases.append(ffn_ln_bias) + self.ffn1_weights.append(ffn1_weight) + self.ffn1_biases.append(ffn1_bias) + self.ffn2_weights.append(ffn2_weight) + self.ffn2_biases.append(ffn2_bias) + + self._add_parameter(ln_scale) + self._add_parameter(ln_bias) + self._add_parameter(qkv_weight) + self._add_parameter(qkv_bias) + self._add_parameter(linear_weight) + self._add_parameter(linear_bias) + + self._add_parameter(ffn_ln_scale) + self._add_parameter(ffn_ln_bias) + self._add_parameter(ffn1_weight) + self._add_parameter(ffn1_bias) + self._add_parameter(ffn2_weight) + self._add_parameter(ffn2_bias) + + self.dropout_rate = config.dropout_rate + + from paddle.incubate.nn.functional import fused_linear + + self.linear = fused_linear + + def get_attr(self, attrs, idx): + if isinstance(attrs, (list, tuple)): + assert len(attrs) == self.num_layers + return attrs[idx] + return attrs + + def _add_parameter(self, param): + if param is None: + return + assert param.name not in self._parameters + self._parameters[param.name] = param + + def get_weight_shape(self, num_heads, dim_feedforward, config): + self.qkv_weight_shape = ( + [3 * num_heads * self.head_dim, self.embed_dim] + if config.trans_qkvw + else [self.embed_dim * 3 * num_heads, self.head_dim] + ) + self.linear_weight_shape = [num_heads * self.head_dim, self.embed_dim] + self.ffn1_weight_shape = ( + [self.embed_dim, dim_feedforward * 2] + if self.activation.endswith("glu") + else [self.embed_dim, dim_feedforward] + ) + self.ffn2_weight_shape = [dim_feedforward, self.embed_dim] + + def get_weight_create_dype(self): + return self._dtype + + def compute_layernorm_before_qkv(self, src, i): + if i == 0: + ln_out = self.norm_func(src, self.ln_scales[i], self.ln_biases[i], self._epsilon, begin_norm_axis=1) + else: + ln_out = src + + return ln_out + + def compute_qkv_linear(self, ln_out, i): + return self.linear(ln_out, self.qkv_weights[i], self.qkv_biases[i], transpose_weight=True) + + def compute_qkv(self, src, residual_input, i): + ln_out = self.compute_layernorm_before_qkv(src, i) + qkv_out = self.compute_qkv_linear(ln_out, i) + return qkv_out, residual_input + + def compute_fmha( + self, + qkv_out, + padding_offset, + seq_lens, + input_ids, + rotary_embs, + rotary_emb_dims, + caches, + pre_caches, + pre_caches_length, + 
attn_mask, + i, + ): + """ + qkv: bsz, seq_len, 3, numhead, headsize -> + q_out: bsz, numhead, seq_len, headsize + kv_out: 2, bsz, numhead, seq_len, headsize + """ + q_out, k_out, v_out = qkv_transpose_split( + qkv_out, padding_offset, seq_lens, input_ids, self.num_heads // self.nranks, self.head_dim + ) + + # rotary emb (inplace) + if rotary_embs is not None: + encode_rotary_qk( + q_out, + k_out, + rotary_embs, + seq_lens, + rotary_emb_dims=rotary_emb_dims, + use_neox=self.use_neox_rotary_style, + ) + + if pre_caches is not None: + k_out = paddle.concat([pre_caches[i][0], k_out], axis=2) + v_out = paddle.concat([pre_caches[i][1], v_out], axis=2) + + # write cache kv (inplace) + write_cache_kv(k_out, v_out, caches[i], seq_lens + pre_caches_length) + + # cutlass fmha + qktv_out = variable_length_memory_efficient_attention( + q_out, + k_out, + v_out, + seq_lens, + seq_lens + pre_caches_length, + mask=attn_mask, + scale=float(self.head_dim**-0.5), + ) + + return transpose_remove_padding(qktv_out, seq_lens, padding_offset) + + def compute_mmha(self, qkv_out, caches, attn_mask, seq_lens, rotary_embs, rotary_emb_dims, i): + return masked_multihead_attention( + x=qkv_out, + cache_kv=caches[i], + src_mask=attn_mask, + sequence_lengths=seq_lens, + rotary_tensor=rotary_embs, + rotary_emb_dims=rotary_emb_dims, + use_neox_rotary_style=self.use_neox_rotary_style, + )[0] + + def compute_out_linear(self, fmha_out, i): + return paddle.matmul(fmha_out, self.linear_weights[i]) + + def compute_attn( + self, + time_step, + qkv_out, + padding_offset, + seq_lens, + input_ids, + rotary_embs, + rotary_emb_dims, + caches, + pre_caches, + pre_caches_length, + attn_mask, + i, + ): + # fmha compute + if time_step is None: # context + fmha_out = self.compute_fmha( + qkv_out, + padding_offset, + seq_lens, + input_ids, + rotary_embs, + rotary_emb_dims, + caches, + pre_caches, + pre_caches_length, + attn_mask, + i, + ) + + else: + fmha_out = self.compute_mmha(qkv_out, caches, attn_mask, seq_lens, rotary_embs, rotary_emb_dims, i) + + out_linear_out = self.compute_out_linear(fmha_out, i) + + return out_linear_out + + def compute_ffn_layernorm(self, out_linear_out, residual_input, i): + norm_out = self.norm_func( + out_linear_out, + norm_weight=self.ffn_ln_scales[i], + norm_bias=self.ffn_ln_biases[i], + epsilon=self._epsilon, + begin_norm_axis=1, + bias=self.linear_biases[i], + residual=residual_input, + ) + tmp_out, residual_input = norm_out[0], norm_out[1] + + return tmp_out, residual_input + + def compute_ffn1(self, tmp_out, i): + return paddle.matmul(tmp_out, self.ffn1_weights[i]) + + def compute_ffn2(self, ffn1_out, i): + return paddle.matmul(ffn1_out, self.ffn2_weights[i]) + + def compute_bias_residual_layernorm(self, ffn2_out, residual_input, i, num_layers): + if i != num_layers - 1: + norm_out = self.norm_func( + ffn2_out, + norm_weight=self.ln_scales[i + 1], + norm_bias=self.ln_biases[i + 1], + epsilon=self._epsilon, + begin_norm_axis=1, + bias=self.ffn2_biases[i], + residual=residual_input, + ) + tmp_out, residual_input = norm_out[0], norm_out[1] + else: + tmp_out = fused_layer_norm( + ffn2_out, + norm_weight=None, + norm_bias=None, + epsilon=self._epsilon, + begin_norm_axis=1, + bias=self.ffn2_biases[i], + residual=residual_input, + )[0] + return tmp_out, residual_input + + def forward( + self, + input_ids, + src, + cum_offsets=None, + padding_offset=None, + attn_mask=None, + caches=None, + pre_caches=None, + pre_caches_length=0, + rotary_embs=None, + rotary_emb_dims=0, + seq_lens=None, + time_step=None, + ): + 
r""" + Applies multi transformer layers on the input. + + Parameters: + src (Tensor): The input of Transformer layers. It is + a tensor with shape `[batch_size, sequence_length, d_model]`. + The data type should be float16 or float32. + attn_mask (Tensor, optional): A tensor used in multi-head attention + to prevents attention to some unwanted positions, usually the + paddings or the subsequent positions. It is a tensor with shape + `[batch_size, 1, sequence_length, sequence_length]`. It can be + None when nothing wanted or needed to be prevented attention to. + Default None. + caches (list(Tensor)|tuple(Tensor), optional): The cache structure + tensors for the inference generation model. It is only used for + inference and should be None for training. The shape is + `[2, batch_size, num_head, max_seq_len, head_dim]`. Default None. + pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches + for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None. + rotary_embs (Tensor optional): The RoPE embs for the rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None. + rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, and it is 0 when rotary_embs is None, + 1 when rotary_embs is not None and pos_extra_ids is None, 2 when rotary_embs and pos_extra_ids are both not None. Default 0. + seq_lens (Tensor optional): The sequence lengths of this batch. The shape is `[bsz]`. Default None. + time_step (Tensor, optional): The time step tensor for the generation + model. Which used in decode stage, to represent the time step, + that is, the real seq_len of CacheKV. The shape is `[1]`, must be + in CPUPlace. Default None. + + Returns: + Tensor|tuple: If `caches` is None, return a tensor that has + the same shape and data type with `src`, representing the output + of Transformer layers. If `caches` is not None, return the + tuple (output, caches), which output is the output of + Transformer layers, caches is inplace with input `caches`. 
+ """ + if caches is not None: + assert len(caches) == len(self.qkv_weights) + + residual_input = src + for i in range(len(caches)): + qkv_out, residual_input = self.compute_qkv(src, residual_input, i) + out_linear_out = self.compute_attn( + time_step, + qkv_out, + padding_offset, + seq_lens, + input_ids, + rotary_embs, + rotary_emb_dims, + caches, + pre_caches, + pre_caches_length, + attn_mask, + i, + ) + # all_reduce + if self.nranks > 1: + dist.all_reduce(out_linear_out) + + # ffn layernorm + tmp_out, residual_input = self.compute_ffn_layernorm(out_linear_out, residual_input, i) + + # ffn1 matmul + ffn1_out = self.compute_ffn1(tmp_out, i) + ffn1_out = fused_act_bias_wrapper(ffn1_out, self.ffn1_biases[i], act_method=self.activation) + + # ffn2 matmul + ffn2_out = self.compute_ffn2(ffn1_out, i) + + # all_reduce + if self.nranks > 1: + dist.all_reduce(ffn2_out) + + # norm + residual_add_bias + tmp_out, residual_input = self.compute_bias_residual_layernorm(ffn2_out, residual_input, i, len(caches)) + src = tmp_out + + if time_step is None: + out = rebuild_padding(tmp_out, cum_offsets, seq_lens, input_ids) + else: + out = tmp_out + return out, caches + + +class FusedMultiTransformerPostLayernorm(FusedMultiTransformerBase): + def __init__(self, config: FusedMultiTransformerConfig): + super().__init__(config) + + def compute_qkv(self, src, residual_input, i): + qkv_out = self.compute_qkv_linear(src, i) + return qkv_out, src + + def compute_ffn_layernorm(self, out_linear_out, residual_input, i): + tmp_out = self.norm_func( + out_linear_out, + norm_weight=self.ln_scales[i], + norm_bias=self.ln_biases[i], + epsilon=self._epsilon, + residual_alpha=self._residual_alpha, + begin_norm_axis=1, + bias=self.linear_biases[i], + residual=residual_input, + )[0] + + return tmp_out, tmp_out + + def compute_bias_residual_layernorm(self, ffn2_out, residual_input, i, num_layers): + tmp_out = self.norm_func( + ffn2_out, + norm_weight=self.ffn_ln_scales[i], + norm_bias=self.ffn_ln_biases[i], + epsilon=self._epsilon, + residual_alpha=self._residual_alpha, + begin_norm_axis=1, + bias=self.ffn2_biases[i], + residual=residual_input, + )[0] + return tmp_out, tmp_out + + +class FusedMultiTransformerWeightOnly(FusedMultiTransformerBase): + def __init__(self, config: FusedMultiTransformerConfig): + super().__init__(config) + self.quant_bits = config.quant_bits + + assert self.quant_bits != -1 + self.weight_dtype = "int" + str(self.quant_bits) + + self.qkv_weights_scale = [] + self.linear_weights_scale = [] + self.ffn1_weights_scale = [] + self.ffn2_weights_scale = [] + + for i in range(self.num_layers): + + qkv_weight_scale_attr = self.get_attr(config.qkv_weight_scale_attrs, i) + linear_weight_scale_attr = self.get_attr(config.linear_weight_scale_attrs, i) + ffn1_weight_scale_attr = self.get_attr(config.ffn1_weight_scale_attrs, i) + ffn2_weight_scale_attr = self.get_attr(config.ffn2_weight_scale_attrs, i) + + qkv_weight_scale = self.create_parameter( + shape=[3 * config.num_heads * self.head_dim], + attr=qkv_weight_scale_attr, + dtype=paddle.float32, + is_bias=False, + ) + + linear_weight_scale = self.create_parameter( + shape=[config.embed_dim], + attr=linear_weight_scale_attr, + dtype=paddle.float32, + is_bias=False, + ) + + ffn1_weight_scale = self.create_parameter( + shape=[config.dim_feedforward * 2], + attr=ffn1_weight_scale_attr, + dtype=paddle.float32, + is_bias=False, + ) + + ffn2_weight_scale = self.create_parameter( + shape=[config.embed_dim], + attr=ffn2_weight_scale_attr, + dtype=paddle.float32, + 
is_bias=False, + ) + + self.qkv_weights_scale.append(qkv_weight_scale) + self.linear_weights_scale.append(linear_weight_scale) + self.ffn1_weights_scale.append(ffn1_weight_scale) + self.ffn2_weights_scale.append(ffn2_weight_scale) + + self._add_parameter(qkv_weight_scale) + self._add_parameter(linear_weight_scale) + self._add_parameter(ffn1_weight_scale) + self._add_parameter(ffn2_weight_scale) + + def get_weight_create_dype(self): + return "int8" # If use weightonly int4, params dtype is int8, and one of the dimension will be half. + + def get_weight_shape(self, num_heads, dim_feedforward, config): + super().get_weight_shape(num_heads, dim_feedforward, config) + + self.linear_weight_shape = [self.embed_dim, num_heads * self.head_dim] + self.ffn1_weight_shape = ( + [dim_feedforward * 2, self.embed_dim] + if self.activation.endswith("glu") + else [dim_feedforward, self.embed_dim] + ) + self.ffn2_weight_shape = [self.embed_dim, dim_feedforward] + + if config.quant_bits == 4: + self.qkv_weight_shape[0] //= 2 + self.linear_weight_shape[0] //= 2 + self.ffn1_weight_shape[0] //= 2 + self.ffn2_weight_shape[0] //= 2 + + def compute_qkv_linear(self, ln_out, i): + return weight_only_linear( + ln_out, + weight=self.qkv_weights[i], + bias=self.qkv_biases[i], + weight_scale=self.qkv_weights_scale[i], + weight_dtype=self.weight_dtype, + ) + + def compute_out_linear(self, fmha_out, i): + return weight_only_linear( + fmha_out, + weight=self.linear_weights[i], + weight_scale=self.linear_weights_scale[i], + weight_dtype=self.weight_dtype, + ) + + def compute_ffn1(self, tmp_out, i): + return weight_only_linear( + tmp_out, + weight=self.ffn1_weights[i], + weight_scale=self.ffn1_weights_scale[i], + weight_dtype=self.weight_dtype, + ) + + def compute_ffn2(self, ffn1_out, i): + return weight_only_linear( + ffn1_out, + weight=self.ffn2_weights[i], + weight_scale=self.ffn2_weights_scale[i], + weight_dtype=self.weight_dtype, + ) + + +class FusedMultiTransformerWeightOnlyPostLayernorm( + FusedMultiTransformerWeightOnly, FusedMultiTransformerPostLayernorm +): + def __init__(self, config: FusedMultiTransformerConfig): + super().__init__(config) + + class FusedMultiTransformer(Layer): def __init__( self, @@ -531,7 +1273,7 @@ def forward( """ if caches is not None: assert len(caches) == len(self.qkv_weights) - bias_residual_input = src + residual_input = src ln_out = src for i in range(len(caches)): if self.normalize_before is True: @@ -630,9 +1372,9 @@ def forward( epsilon=self._epsilon, begin_norm_axis=1, bias=self.linear_biases[i], - residual=bias_residual_input, + residual=residual_input, ) - tmp_out, bias_residual_input = norm_out[0], norm_out[1] + tmp_out, residual_input = norm_out[0], norm_out[1] else: tmp_out = self.norm_func( out_linear_out, @@ -682,9 +1424,9 @@ def forward( epsilon=self._epsilon, begin_norm_axis=1, bias=self.ffn2_biases[i], - residual=bias_residual_input, + residual=residual_input, ) - tmp_out, bias_residual_input = norm_out[0], norm_out[1] + tmp_out, residual_input = norm_out[0], norm_out[1] else: tmp_out = fused_layer_norm( ffn2_out, @@ -693,7 +1435,7 @@ def forward( epsilon=self._epsilon, begin_norm_axis=1, bias=self.ffn2_biases[i], - residual=bias_residual_input, + residual=residual_input, )[0] else: tmp_out = self.norm_func( diff --git a/paddlenlp/experimental/transformers/gpt/modeling.py b/paddlenlp/experimental/transformers/gpt/modeling.py index d41ea16b65f2..4dd653452bd0 100644 --- a/paddlenlp/experimental/transformers/gpt/modeling.py +++ 
b/paddlenlp/experimental/transformers/gpt/modeling.py @@ -19,7 +19,8 @@ from paddlenlp_ops import get_padding_offset from paddlenlp.experimental.transformers.fused_transformer_layers import ( - FusedMultiTransformer, + FusedMultiTransformerBase, + FusedMultiTransformerConfig, ) from paddlenlp.experimental.transformers.generation_utils import ( GenerationInferenceModel, @@ -111,7 +112,7 @@ def __init__(self, config: GPTConfig): ffn2_bias_attrs = [ paddle.ParamAttr(name="gpt.decoder.layers.{}.linear2.bias".format(i)) for i in range(self.num_layers) ] - self.transformer_block = FusedMultiTransformer( + transformer_config = FusedMultiTransformerConfig( config.hidden_size, config.num_attention_heads, 4 * config.hidden_size, @@ -134,6 +135,7 @@ def __init__(self, config: GPTConfig): epsilon=1e-5, norm_type="layernorm", ) + self.transformer_block = FusedMultiTransformerBase(transformer_config) self.norm = nn.LayerNorm(config.hidden_size, epsilon=1e-5) def get_input_embeddings(self): diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py index 69dcd855c678..849b9f1c48e8 100644 --- a/paddlenlp/experimental/transformers/llama/modeling.py +++ b/paddlenlp/experimental/transformers/llama/modeling.py @@ -21,7 +21,9 @@ from paddlenlp_ops import fused_get_rotary_embedding, get_padding_offset from paddlenlp.experimental.transformers.fused_transformer_layers import ( - FusedMultiTransformer, + FusedMultiTransformerBase, + FusedMultiTransformerConfig, + FusedMultiTransformerWeightOnly, ) from paddlenlp.experimental.transformers.generation_utils import ( GenerationInferenceModel, @@ -161,7 +163,7 @@ def __init__(self, config: LlamaConfig): paddle.ParamAttr(name="fusellama.{}.ffn2_weight_scale".format(i)) for i in range(self.num_layers) ] - self.transformer_block = FusedMultiTransformer( + transformer_config = FusedMultiTransformerConfig( self.hidden_size, self.num_attention_heads, self.intermediate_size, @@ -184,6 +186,12 @@ def __init__(self, config: LlamaConfig): norm_type="rmsnorm", use_neox_rotary_style=True, ) + + if self.use_weight_only: + self.transformer_block = FusedMultiTransformerWeightOnly(transformer_config) + else: + self.transformer_block = FusedMultiTransformerBase(transformer_config) + self.norm = FusedLlamaRMSNorm(config) self.cache_kvs = None diff --git a/paddlenlp/experimental/transformers/opt/modeling.py b/paddlenlp/experimental/transformers/opt/modeling.py index e333a8d6071e..ac1a321e4ccd 100644 --- a/paddlenlp/experimental/transformers/opt/modeling.py +++ b/paddlenlp/experimental/transformers/opt/modeling.py @@ -21,7 +21,8 @@ from paddlenlp_ops import get_padding_offset from paddlenlp.experimental.transformers.fused_transformer_layers import ( - FusedMultiTransformer, + FusedMultiTransformerBase, + FusedMultiTransformerConfig, ) from paddlenlp.experimental.transformers.generation_utils import ( GenerationInferenceModel, @@ -110,7 +111,7 @@ def __init__(self, config: OPTConfig): for i in range(config.num_hidden_layers) ] - self.transformer_block = FusedMultiTransformer( + transformer_config = FusedMultiTransformerConfig( config.hidden_size, config.num_attention_heads, config.intermediate_size, @@ -135,6 +136,8 @@ def __init__(self, config: OPTConfig): epsilon=self.epsilon, ) + self.transformer_block = FusedMultiTransformerBase(transformer_config) + def get_input_embeddings(self): return self.embeddings.word_embeddings diff --git a/tests/llm/test_predictor.py b/tests/llm/test_predictor.py index c64310b36ac7..201d714be77b 
100644 --- a/tests/llm/test_predictor.py +++ b/tests/llm/test_predictor.py @@ -19,8 +19,7 @@ from parameterized import parameterized_class from paddlenlp.transformers import AutoTokenizer, LlamaForCausalLM - -from .testing_utils import LLMTest +from tests.llm.testing_utils import LLMTest @parameterized_class( @@ -53,3 +52,24 @@ def test_predictor(self): self.assertGreaterEqual(full_match / len(result_0), 0.25) self.assertGreater(count / len(result_0), 0.4) + + def test_wint8(self): + self.run_predictor({"inference_model": True, "quant_type": "weight_only_int8"}) + result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) + self.run_predictor({"inference_model": False}) + result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) + + assert len(result_0) == len(result_1) + + count, full_match = 0, 0 + for inference_item, no_inference_item in zip(result_0, result_1): + min_length = min(len(inference_item), len(no_inference_item)) + count += int(inference_item[min_length // 2] == no_inference_item[min_length // 2]) + full_match += int(inference_item[:min_length] == no_inference_item[:min_length]) + + self.assertGreaterEqual(full_match / len(result_0), 0.15) + self.assertGreater(count / len(result_0), 0.4) + + +if __name__ == "__main__": + unittest.main() From 89ce638b58d8bbcc81bfe688c3f77d4dfb15e67b Mon Sep 17 00:00:00 2001 From: wufeisheng Date: Tue, 10 Oct 2023 19:37:37 +0800 Subject: [PATCH 2/3] delete origin class --- .../transformers/fused_transformer_layers.py | 584 ------------------ 1 file changed, 584 deletions(-) diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py index 0dfd9143c159..695449e62a86 100644 --- a/paddlenlp/experimental/transformers/fused_transformer_layers.py +++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py @@ -872,587 +872,3 @@ class FusedMultiTransformerWeightOnlyPostLayernorm( ): def __init__(self, config: FusedMultiTransformerConfig): super().__init__(config) - - -class FusedMultiTransformer(Layer): - def __init__( - self, - embed_dim, - num_heads, - dim_feedforward, - quant_bits=-1, # -1 means use Half precision. 
- dropout_rate=0.0, - activation="gelu", - norm_type="layernorm", - use_neox_rotary_style=False, - normalize_before=True, - ln_scale_attrs=None, - ln_bias_attrs=None, - qkv_weight_attrs=None, - qkv_weight_scale_attrs=None, - qkv_bias_attrs=None, - linear_weight_attrs=None, - linear_weight_scale_attrs=None, - linear_bias_attrs=None, - ffn_ln_scale_attrs=None, - ffn_ln_bias_attrs=None, - ffn1_weight_attrs=None, - ffn1_weight_scale_attrs=None, - ffn1_bias_attrs=None, - ffn2_weight_attrs=None, - ffn2_weight_scale_attrs=None, - ffn2_bias_attrs=None, - epsilon=1e-5, - residual_alpha=1.0, - num_layers=-1, - nranks=1, - trans_qkvw=True, - ring_id=-1, - name=None, - ): - super().__init__() - - assert embed_dim > 0, "Expected embed_dim to be greater than 0, " "but received {}".format(embed_dim) - assert num_heads > 0, "Expected nhead to be greater than 0, " "but received {}".format(num_heads) - assert dim_feedforward > 0, "Expected dim_feedforward to be greater than 0, but received {}".format( - dim_feedforward - ) - - self.normalize_before = normalize_before - self._dtype = self._helper.get_default_dtype() - self._epsilon = epsilon - self._residual_alpha = residual_alpha - self._trans_qkvw = trans_qkvw - self._ring_id = ring_id - self.nranks = nranks - self.norm_type = norm_type - if norm_type == "layernorm": - self.norm_func = fused_layer_norm - elif norm_type == "rmsnorm": - self.norm_func = fused_rms_norm - else: - raise NotImplementedError("Only support norm type of [layernorm, rmsnorm]") - self.use_neox_rotary_style = use_neox_rotary_style - self._norm_weight_dtype = "float32" if self.norm_type == "layernorm" else self._dtype - - self.embed_dim = embed_dim - self.num_heads = num_heads - self.head_dim = embed_dim // num_heads - assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" - - # tensor model parallel - if nranks > 1: - assert ring_id != -1 - assert num_heads % nranks == 0 - assert dim_feedforward % nranks == 0 - num_heads = num_heads // nranks - dim_feedforward = dim_feedforward // nranks - self._dim_feedforward = dim_feedforward - - if isinstance(qkv_weight_attrs, (list, tuple)): - num_layers = len(qkv_weight_attrs) - assert num_layers > 0 - - self.quant_bits = quant_bits - self.use_weight_only = False - self.weight_dtype = self._dtype - self.create_params_type = self._dtype - - if self.quant_bits != -1: - self.use_weight_only = True - self.create_params_type = ( - "int8" # If use weightonly int4, params dtype is int8, and one of the dimension will be half. 
- ) - self.weight_dtype = "int" + str(self.quant_bits) - - self.ln_scales, self.ln_biases = [], [] - self.qkv_weights, self.qkv_weights_scale, self.qkv_biases = [], [], [] - self.linear_weights, self.linear_weights_scale, self.linear_biases = [], [], [] - self.ffn_ln_scales, self.ffn_ln_biases = [], [] - self.ffn1_weights, self.ffn1_weights_scale, self.ffn1_biases = [], [], [] - self.ffn2_weights, self.ffn2_weights_scale, self.ffn2_biases = [], [], [] - - def get_attr(attrs, idx): - if isinstance(attrs, (list, tuple)): - assert len(attrs) == num_layers - return attrs[idx] - return attrs - - def _add_parameter(param): - if param is None: - return - assert param.name not in self._parameters - self._parameters[param.name] = param - - for i in range(num_layers): - ln_scale_attr = get_attr(ln_scale_attrs, i) - ln_bias_attr = get_attr(ln_bias_attrs, i) - qkv_weight_attr = get_attr(qkv_weight_attrs, i) - qkv_weight_scale_attr = get_attr(qkv_weight_scale_attrs, i) - - qkv_bias_attr = get_attr(qkv_bias_attrs, i) - linear_weight_attr = get_attr(linear_weight_attrs, i) - linear_weight_scale_attr = get_attr(linear_weight_scale_attrs, i) - linear_bias_attr = get_attr(linear_bias_attrs, i) - - ffn_ln_scale_attr = get_attr(ffn_ln_scale_attrs, i) - ffn_ln_bias_attr = get_attr(ffn_ln_bias_attrs, i) - ffn1_weight_attr = get_attr(ffn1_weight_attrs, i) - ffn1_weight_scale_attr = get_attr(ffn1_weight_scale_attrs, i) - ffn1_bias_attr = get_attr(ffn1_bias_attrs, i) - ffn2_weight_attr = get_attr(ffn2_weight_attrs, i) - ffn2_weight_scale_attr = get_attr(ffn2_weight_scale_attrs, i) - ffn2_bias_attr = get_attr(ffn2_bias_attrs, i) - - ln_scale = self.create_parameter( - attr=ln_scale_attr, - shape=[embed_dim], - default_initializer=Constant(value=1.0), - dtype=self._norm_weight_dtype, - ) - ln_bias = None - if ln_bias_attr: - ln_bias = self.create_parameter( - attr=ln_bias_attr, - shape=[embed_dim], - is_bias=True, - dtype=self._norm_weight_dtype, - ) - - # Note(Zhengzekang): Weightonly need weight is ColMajor layout. 
- qkv_weight_shape = ( - [3 * num_heads * self.head_dim, embed_dim] - if trans_qkvw - else [embed_dim * 3 * num_heads, self.head_dim] - ) - qkv_weight_scale = None - if self.use_weight_only: - if self.quant_bits == 4: - qkv_weight_shape[0] //= 2 - - qkv_weight_scale = self.create_parameter( - shape=[3 * num_heads * self.head_dim], - attr=qkv_weight_scale_attr, - dtype=paddle.float32, - is_bias=False, - ) - - qkv_weight = self.create_parameter( - shape=qkv_weight_shape, - attr=qkv_weight_attr, - dtype=self.create_params_type, - is_bias=False, - ) - - qkv_bias = None - if qkv_bias_attr: - qkv_bias = self.create_parameter( - shape=[3 * num_heads * self.head_dim], - attr=qkv_bias_attr, - dtype=self._dtype, - is_bias=True, - ) - - linear_weight_shape = [num_heads * self.head_dim, embed_dim] - linear_weight_scale = None - if self.use_weight_only: - linear_weight_shape = [embed_dim, num_heads * self.head_dim] - if self.quant_bits == 4: - linear_weight_shape[0] //= 2 - - linear_weight_scale = self.create_parameter( - shape=[embed_dim], - attr=linear_weight_scale_attr, - dtype=paddle.float32, - is_bias=False, - ) - linear_weight = self.create_parameter( - shape=linear_weight_shape, - attr=linear_weight_attr, - dtype=self.create_params_type, - is_bias=False, - ) - - linear_bias = None - if linear_bias_attr: - linear_bias = self.create_parameter( - shape=[embed_dim], - attr=linear_bias_attr, - dtype=self._dtype, - is_bias=True, - ) - - ffn_ln_scale = self.create_parameter( - shape=[embed_dim], - attr=ffn_ln_scale_attr, - is_bias=False, - default_initializer=Constant(1.0), - dtype=self._norm_weight_dtype, - ) - - ffn_ln_bias = None - if ffn_ln_bias_attr: - ffn_ln_bias = self.create_parameter( - shape=[embed_dim], - attr=ffn_ln_bias_attr, - is_bias=True, - dtype=self._norm_weight_dtype, - ) - - ffn1_weight_shape = ( - [embed_dim, dim_feedforward * 2] if activation.endswith("glu") else [embed_dim, dim_feedforward] - ) - ffn1_weight_scale = None - if self.use_weight_only: - ffn1_weight_shape = ( - [dim_feedforward * 2, embed_dim] if activation.endswith("glu") else [dim_feedforward, embed_dim] - ) - if self.quant_bits == 4: - ffn1_weight_shape[0] //= 2 - - ffn1_weight_scale = self.create_parameter( - shape=[dim_feedforward * 2], - attr=ffn1_weight_scale_attr, - dtype=paddle.float32, - is_bias=False, - ) - ffn1_weight = self.create_parameter( - shape=ffn1_weight_shape, - attr=ffn1_weight_attr, - dtype=self.create_params_type, - is_bias=False, - ) - - ffn1_bias = None - if ffn1_bias_attr: - ffn1_bias = self.create_parameter( - shape=[dim_feedforward * 2] if activation.endswith("glu") else [dim_feedforward], - attr=ffn1_bias_attr, - dtype=self._dtype, - is_bias=True, - ) - - ffn2_weight_shape = [dim_feedforward, embed_dim] - ffn2_weight_scale = None - if self.use_weight_only: - ffn2_weight_shape = [embed_dim, dim_feedforward] - if self.quant_bits == 4: - ffn2_weight_shape[0] //= 2 - - ffn2_weight_scale = self.create_parameter( - shape=[embed_dim], - attr=ffn2_weight_scale_attr, - dtype=paddle.float32, - is_bias=False, - ) - - ffn2_weight = self.create_parameter( - shape=ffn2_weight_shape, - attr=ffn2_weight_attr, - dtype=self.create_params_type, - is_bias=False, - ) - - ffn2_bias = None - if ffn2_bias_attr: - ffn2_bias = self.create_parameter( - shape=[embed_dim], - attr=ffn2_bias_attr, - dtype=self._dtype, - is_bias=True, - ) - - # tensor model parallel - if nranks > 1: - # column parallel - _set_var_distributed(qkv_weight) - _set_var_distributed(qkv_bias) - _set_var_distributed(ffn1_weight) - 
_set_var_distributed(ffn1_bias) - # row parallel - _set_var_distributed(linear_weight) - _set_var_distributed(ffn2_weight) - - self.ln_scales.append(ln_scale) - self.ln_biases.append(ln_bias) - self.qkv_weights.append(qkv_weight) - self.qkv_biases.append(qkv_bias) - self.linear_weights.append(linear_weight) - self.linear_biases.append(linear_bias) - - self.ffn_ln_scales.append(ffn_ln_scale) - self.ffn_ln_biases.append(ffn_ln_bias) - self.ffn1_weights.append(ffn1_weight) - self.ffn1_biases.append(ffn1_bias) - self.ffn2_weights.append(ffn2_weight) - self.ffn2_biases.append(ffn2_bias) - - if self.use_weight_only: - self.qkv_weights_scale.append(qkv_weight_scale) - self.linear_weights_scale.append(linear_weight_scale) - self.ffn1_weights_scale.append(ffn1_weight_scale) - self.ffn2_weights_scale.append(ffn2_weight_scale) - - _add_parameter(ln_scale) - _add_parameter(ln_bias) - _add_parameter(qkv_weight) - _add_parameter(qkv_bias) - _add_parameter(linear_weight) - _add_parameter(linear_bias) - - _add_parameter(ffn_ln_scale) - _add_parameter(ffn_ln_bias) - _add_parameter(ffn1_weight) - _add_parameter(ffn1_bias) - _add_parameter(ffn2_weight) - _add_parameter(ffn2_bias) - - if self.use_weight_only: - _add_parameter(qkv_weight_scale) - _add_parameter(linear_weight_scale) - _add_parameter(ffn1_weight_scale) - _add_parameter(ffn2_weight_scale) - - self.dropout_rate = dropout_rate - self.activation = activation - self.name = name - - from paddle.incubate.nn.functional import fused_linear - - self.linear = fused_linear - - def forward( - self, - input_ids, - src, - cum_offsets=None, - padding_offset=None, - attn_mask=None, - caches=None, - pre_caches=None, - pre_caches_length=0, - rotary_embs=None, - rotary_emb_dims=0, - seq_lens=None, - time_step=None, - ): - r""" - Applies multi transformer layers on the input. - - Parameters: - src (Tensor): The input of Transformer layers. It is - a tensor with shape `[batch_size, sequence_length, d_model]`. - The data type should be float16 or float32. - attn_mask (Tensor, optional): A tensor used in multi-head attention - to prevents attention to some unwanted positions, usually the - paddings or the subsequent positions. It is a tensor with shape - `[batch_size, 1, sequence_length, sequence_length]`. It can be - None when nothing wanted or needed to be prevented attention to. - Default None. - caches (list(Tensor)|tuple(Tensor), optional): The cache structure - tensors for the inference generation model. It is only used for - inference and should be None for training. The shape is - `[2, batch_size, num_head, max_seq_len, head_dim]`. Default None. - pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches - for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None. - rotary_embs (Tensor optional): The RoPE embs for the rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None. - rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, and it is 0 when rotary_embs is None, - 1 when rotary_embs is not None and pos_extra_ids is None, 2 when rotary_embs and pos_extra_ids are both not None. Default 0. - seq_lens (Tensor optional): The sequence lengths of this batch. The shape is `[bsz]`. Default None. - time_step (Tensor, optional): The time step tensor for the generation - model. Which used in decode stage, to represent the time step, - that is, the real seq_len of CacheKV. The shape is `[1]`, must be - in CPUPlace. Default None. 
- - Returns: - Tensor|tuple: If `caches` is None, return a tensor that has - the same shape and data type with `src`, representing the output - of Transformer layers. If `caches` is not None, return the - tuple (output, caches), which output is the output of - Transformer layers, caches is inplace with input `caches`. - """ - if caches is not None: - assert len(caches) == len(self.qkv_weights) - residual_input = src - ln_out = src - for i in range(len(caches)): - if self.normalize_before is True: - # layernorm - if i == 0: - ln_out = self.norm_func( - src, self.ln_scales[i], self.ln_biases[i], self._epsilon, begin_norm_axis=1 - ) - - # qkv compute - if self.use_weight_only: - qkv_out = weight_only_linear( - ln_out, - weight=self.qkv_weights[i], - bias=self.qkv_biases[i], - weight_scale=self.qkv_weights_scale[i], - weight_dtype=self.weight_dtype, - ) - else: - qkv_out = self.linear(ln_out, self.qkv_weights[i], self.qkv_biases[i], transpose_weight=True) - - # fmha compute - if time_step is None: # context - """ - qkv: bsz, seq_len, 3, numhead, headsize -> - q_out: bsz, numhead, seq_len, headsize - kv_out: 2, bsz, numhead, seq_len, headsize - """ - q_out, k_out, v_out = qkv_transpose_split( - qkv_out, padding_offset, seq_lens, input_ids, self.num_heads // self.nranks, self.head_dim - ) - - # rotary emb (inplace) - if rotary_embs is not None: - encode_rotary_qk( - q_out, - k_out, - rotary_embs, - seq_lens, - rotary_emb_dims=rotary_emb_dims, - use_neox=self.use_neox_rotary_style, - ) - - if pre_caches is not None: - k_out = paddle.concat([pre_caches[i][0], k_out], axis=2) - v_out = paddle.concat([pre_caches[i][1], v_out], axis=2) - - # write cache kv (inplace) - write_cache_kv(k_out, v_out, caches[i], seq_lens + pre_caches_length) - - # cutlass fmha - qktv_out = variable_length_memory_efficient_attention( - q_out, - k_out, - v_out, - seq_lens, - seq_lens + pre_caches_length, - mask=attn_mask, - scale=float(self.head_dim**-0.5), - ) - - fmha_out = transpose_remove_padding(qktv_out, seq_lens, padding_offset) - - else: - fmha_out = masked_multihead_attention( - x=qkv_out, - cache_kv=caches[i], - src_mask=attn_mask, - sequence_lengths=seq_lens, - rotary_tensor=rotary_embs, - rotary_emb_dims=rotary_emb_dims, - use_neox_rotary_style=self.use_neox_rotary_style, - )[0] - - # out_linear - if self.use_weight_only: - out_linear_out = weight_only_linear( - fmha_out, - weight=self.linear_weights[i], - weight_scale=self.linear_weights_scale[i], - weight_dtype=self.weight_dtype, - ) - else: - out_linear_out = paddle.matmul(fmha_out, self.linear_weights[i]) - - # all_reduce - if self.nranks > 1: - dist.all_reduce(out_linear_out) - - # norm + residual_add_bias - if self.normalize_before is True: - norm_out = self.norm_func( - out_linear_out, - norm_weight=self.ffn_ln_scales[i], - norm_bias=self.ffn_ln_biases[i], - epsilon=self._epsilon, - begin_norm_axis=1, - bias=self.linear_biases[i], - residual=residual_input, - ) - tmp_out, residual_input = norm_out[0], norm_out[1] - else: - tmp_out = self.norm_func( - out_linear_out, - norm_weight=self.ln_scales[i], - norm_bias=self.ln_biases[i], - epsilon=self._epsilon, - residual_alpha=self._residual_alpha, - begin_norm_axis=1, - bias=self.linear_biases[i], - residual=ln_out, - )[0] - - # ffn1 matmul - if self.use_weight_only: - ffn1_out = weight_only_linear( - tmp_out, - weight=self.ffn1_weights[i], - weight_scale=self.ffn1_weights_scale[i], - weight_dtype=self.weight_dtype, - ) - else: - ffn1_out = paddle.matmul(tmp_out, self.ffn1_weights[i]) - ffn1_out = 
fused_act_bias_wrapper(ffn1_out, self.ffn1_biases[i], act_method=self.activation) - - # ffn2 matmul - if self.use_weight_only: - ffn2_out = weight_only_linear( - ffn1_out, - weight=self.ffn2_weights[i], - weight_scale=self.ffn2_weights_scale[i], - weight_dtype=self.weight_dtype, - ) - else: - ffn2_out = paddle.matmul(ffn1_out, self.ffn2_weights[i]) - - # all_reduce - if self.nranks > 1: - dist.all_reduce(ffn2_out) - - # norm + residual_add_bias - if self.normalize_before is True: - if i != len(caches) - 1: - norm_out = self.norm_func( - ffn2_out, - norm_weight=self.ln_scales[i + 1], - norm_bias=self.ln_biases[i + 1], - epsilon=self._epsilon, - begin_norm_axis=1, - bias=self.ffn2_biases[i], - residual=residual_input, - ) - tmp_out, residual_input = norm_out[0], norm_out[1] - else: - tmp_out = fused_layer_norm( - ffn2_out, - norm_weight=None, - norm_bias=None, - epsilon=self._epsilon, - begin_norm_axis=1, - bias=self.ffn2_biases[i], - residual=residual_input, - )[0] - else: - tmp_out = self.norm_func( - ffn2_out, - norm_weight=self.ffn_ln_scales[i], - norm_bias=self.ffn_ln_biases[i], - epsilon=self._epsilon, - residual_alpha=self._residual_alpha, - begin_norm_axis=1, - bias=self.ffn2_biases[i], - residual=tmp_out, - )[0] - - ln_out = tmp_out - - if time_step is None: - out = rebuild_padding(tmp_out, cum_offsets, seq_lens, input_ids) - else: - out = tmp_out - return out, caches From 86283ade3a7d215b368c9a19c93b9a678aaaa691 Mon Sep 17 00:00:00 2001 From: wufeisheng Date: Thu, 12 Oct 2023 10:37:12 +0800 Subject: [PATCH 3/3] code refine --- .../transformers/fused_transformer_layers.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py index 695449e62a86..c93a87e53a89 100644 --- a/paddlenlp/experimental/transformers/fused_transformer_layers.py +++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py @@ -252,9 +252,10 @@ def __init__(self, config: FusedMultiTransformerConfig): dim_feedforward = config.dim_feedforward // config.nranks self._dim_feedforward = dim_feedforward - if isinstance(config.qkv_weight_attrs, (list, tuple)): - self.num_layers = len(config.qkv_weight_attrs) + self.num_layers = config.num_layers assert self.num_layers > 0 + if isinstance(config.qkv_weight_attrs, (list, tuple)): + assert self.num_layers == len(config.qkv_weight_attrs) self.weight_dtype = self._dtype self.create_params_type = self.get_weight_create_dype() @@ -297,7 +298,7 @@ def __init__(self, config: FusedMultiTransformerConfig): dtype=self._norm_weight_dtype, ) - self.get_weight_shape(num_heads, dim_feedforward, config) + self.init_weight_shape(num_heads, dim_feedforward, config) qkv_weight = self.create_parameter( shape=self.qkv_weight_shape, @@ -437,7 +438,7 @@ def _add_parameter(self, param): assert param.name not in self._parameters self._parameters[param.name] = param - def get_weight_shape(self, num_heads, dim_feedforward, config): + def init_weight_shape(self, num_heads, dim_feedforward, config): self.qkv_weight_shape = ( [3 * num_heads * self.head_dim, self.embed_dim] if config.trans_qkvw @@ -816,8 +817,8 @@ def __init__(self, config: FusedMultiTransformerConfig): def get_weight_create_dype(self): return "int8" # If use weightonly int4, params dtype is int8, and one of the dimension will be half. 
- def get_weight_shape(self, num_heads, dim_feedforward, config): - super().get_weight_shape(num_heads, dim_feedforward, config) + def init_weight_shape(self, num_heads, dim_feedforward, config): + super().init_weight_shape(num_heads, dim_feedforward, config) self.linear_weight_shape = [self.embed_dim, num_heads * self.head_dim] self.ffn1_weight_shape = (