modify block_multihead_attention api (#8456)
* modify block_multihead_attention api

* add param to blha

* modify fused_transformer_layers

* fix bug
ming1753 authored May 28, 2024
1 parent 85ba573 commit c1cfe63
Showing 1 changed file with 17 additions and 0 deletions.
paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -584,6 +584,13 @@ def compute_qkv(self, src, residual_input, i):
             qkv_out = self.compute_qkv_linear(ln_out, i)
         return qkv_out, residual_input
 
+    def compute_max_len(self, seq_lens_encoder, seq_lens_decoder, cum_offsets):
+        if seq_lens_encoder is None or seq_lens_decoder is None or cum_offsets is None:
+            return None, None
+        return paddle.incubate.nn.functional.blha_get_max_len(
+            seq_lens_encoder, seq_lens_decoder, cum_offsets  # cum_offsets.shape[0] used as bsz
+        )
+
     def compute_fmha(
         self,
         qkv_out,
@@ -816,6 +823,12 @@ def forward(
 
         assert self.num_layers == len(self.qkv_weights)
 
+        max_enc_len_this_time, max_dec_len_this_time = self.compute_max_len(
+            kwargs.get("seq_lens_encoder", None), kwargs.get("seq_lens_decoder", None), cum_offsets
+        )
+        kwargs["max_enc_len_this_time"] = max_enc_len_this_time
+        kwargs["max_dec_len_this_time"] = max_dec_len_this_time
+
         residual_input = src
         for i in range(self.num_layers):
             qkv_out, residual_input = self.compute_qkv(src, residual_input, i)
@@ -1385,6 +1398,8 @@ def compute_attn(
             None,  # qkv_bias
             None,  # out_shifts
             None,  # out_smooths
+            kwargs.get("max_enc_len_this_time", None),
+            kwargs.get("max_dec_len_this_time", None),
             rotary_embs,
             attn_mask,
             kwargs.get("tgt_mask", None),
@@ -1471,6 +1486,8 @@ def compute_attn(
             self.qkv_biases[i] if len(self.qkv_biases) > 0 else None,
             self.linear_shifts[i] if len(self.linear_shifts) > 0 else None,
             self.linear_smooths[i] if len(self.linear_smooths) > 0 else None,
+            kwargs.get("max_enc_len_this_time", None),
+            kwargs.get("max_dec_len_this_time", None),
             rotary_embs,
             attn_mask,
             kwargs.get("tgt_mask", None),
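The net effect of the diff: the per-batch maximum encoder and decoder sequence lengths are now computed once per forward pass via blha_get_max_len and threaded through kwargs into each layer's fused attention call as the new max_enc_len_this_time / max_dec_len_this_time arguments. Below is a minimal sketch of the helper's core call, assuming a GPU build of PaddlePaddle recent enough to ship paddle.incubate.nn.functional.blha_get_max_len; the tensor values and the batch_like name are illustrative only, standing in for the real cum_offsets tensor, of which only shape[0] (the batch size) is read.

import paddle

paddle.device.set_device("gpu")

# Per-example sequence lengths for a mixed batch: rows 0-1 are in the
# prefill (encoder) phase, rows 2-3 are in the decode phase.
seq_lens_encoder = paddle.to_tensor([128, 64, 0, 0], dtype="int32")
seq_lens_decoder = paddle.to_tensor([0, 0, 12, 7], dtype="int32")
batch_like = paddle.zeros([4], dtype="int32")  # only shape[0] (bsz) is read

max_enc_len, max_dec_len = paddle.incubate.nn.functional.blha_get_max_len(
    seq_lens_encoder, seq_lens_decoder, batch_like
)
# Expected: max_enc_len == 128, max_dec_len == 12; these are the values the
# commit stores as kwargs["max_enc_len_this_time"] / kwargs["max_dec_len_this_time"].

Precomputing the two maxima once and reusing them across all layers avoids repeating the same reduction in every call to block_multihead_attention.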
