diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py
index fdc9c5f666d7..3d7010599377 100644
--- a/paddlenlp/experimental/transformers/fused_transformer_layers.py
+++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -584,6 +584,13 @@ def compute_qkv(self, src, residual_input, i):
         qkv_out = self.compute_qkv_linear(ln_out, i)
         return qkv_out, residual_input
 
+    def compute_max_len(self, seq_lens_encoder, seq_lens_decoder, cum_offsets):
+        if seq_lens_encoder is None or seq_lens_decoder is None or cum_offsets is None:
+            return None, None
+        return paddle.incubate.nn.functional.blha_get_max_len(
+            seq_lens_encoder, seq_lens_decoder, cum_offsets  # cum_offsets.shape[0] used as bsz
+        )
+
     def compute_fmha(
         self,
         qkv_out,
@@ -816,6 +823,12 @@ def forward(
 
         assert self.num_layers == len(self.qkv_weights)
 
+        max_enc_len_this_time, max_dec_len_this_time = self.compute_max_len(
+            kwargs.get("seq_lens_encoder", None), kwargs.get("seq_lens_decoder", None), cum_offsets
+        )
+        kwargs["max_enc_len_this_time"] = max_enc_len_this_time
+        kwargs["max_dec_len_this_time"] = max_dec_len_this_time
+
         residual_input = src
         for i in range(self.num_layers):
             qkv_out, residual_input = self.compute_qkv(src, residual_input, i)
@@ -1385,6 +1398,8 @@ def compute_attn(
             None,  # qkv_bias
             None,  # out_shifts
             None,  # out_smooths
+            kwargs.get("max_enc_len_this_time", None),
+            kwargs.get("max_dec_len_this_time", None),
             rotary_embs,
             attn_mask,
             kwargs.get("tgt_mask", None),
@@ -1471,6 +1486,8 @@ def compute_attn(
             self.qkv_biases[i] if len(self.qkv_biases) > 0 else None,
             self.linear_shifts[i] if len(self.linear_shifts) > 0 else None,
             self.linear_smooths[i] if len(self.linear_smooths) > 0 else None,
+            kwargs.get("max_enc_len_this_time", None),
+            kwargs.get("max_dec_len_this_time", None),
             rotary_embs,
             attn_mask,
             kwargs.get("tgt_mask", None),
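
The diff computes max_enc_len_this_time and max_dec_len_this_time once per forward pass and threads them through kwargs into both compute_attn call sites. The sketch below is illustrative only, not part of the patch: the sample tensor values, the assumption that seq_lens_encoder/seq_lens_decoder are int32 tensors of shape [bsz] (with cum_offsets supplying the batch size through its leading dimension, as the inline comment notes), and the assumption that blha_get_max_len needs a GPU build of Paddle are all mine.

import paddle

# Hypothetical per-request sequence lengths, shape [bsz], int32 (assumption).
seq_lens_encoder = paddle.to_tensor([8, 0, 5], dtype="int32")
seq_lens_decoder = paddle.to_tensor([0, 12, 0], dtype="int32")
cum_offsets = paddle.zeros([3], dtype="int32")  # only its leading dim (bsz) matters here

# Same call the new compute_max_len helper makes; it returns the largest encoder
# and decoder lengths in the current batch so the attention kernels can size work.
max_enc_len_this_time, max_dec_len_this_time = paddle.incubate.nn.functional.blha_get_max_len(
    seq_lens_encoder, seq_lens_decoder, cum_offsets
)
# Conceptually: max_enc_len_this_time == paddle.max(seq_lens_encoder),
#               max_dec_len_this_time == paddle.max(seq_lens_decoder).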