diff --git a/examples/large_models/tp_llama/llama2.py b/examples/large_models/tp_llama/llama2.py
index f30930548e..a3a8883f3c 100644
--- a/examples/large_models/tp_llama/llama2.py
+++ b/examples/large_models/tp_llama/llama2.py
@@ -206,7 +206,8 @@ def forward(
         keys = repeat_kv(keys, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
         values = repeat_kv(values, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
         #calling PT SDPA to enable using Flash Attention 2 and Xformer memory efficient kernels.
-        output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, is_causal=True)
+        output = torch.nn.functional.scaled_dot_product_attention(xq.transpose(1,2), keys.transpose(1,2), values.transpose(1,2), attn_mask=mask, dropout_p=0.0, is_causal=False)
+        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
         return self.wo(output)
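
For context: `torch.nn.functional.scaled_dot_product_attention` expects query, key, and value tensors laid out as `(bs, n_heads, seqlen, head_dim)`, while the projections in this model are kept as `(bs, seqlen, n_local_heads, head_dim)`. The `transpose(1, 2)` calls in the diff adapt the layout, and the final transpose-and-reshape restores the `(bs, seqlen, n_heads * head_dim)` shape that the output projection `self.wo` consumes. Below is a minimal standalone sketch of that shape contract; the dimensions are made up for illustration and are not taken from the model config.

```python
import torch
import torch.nn.functional as F

# Hypothetical dimensions, for illustration only.
bsz, seqlen, n_local_heads, head_dim = 2, 16, 4, 64

# Projections leave the attention block as (bs, seqlen, n_heads, head_dim).
xq = torch.randn(bsz, seqlen, n_local_heads, head_dim)
keys = torch.randn(bsz, seqlen, n_local_heads, head_dim)
values = torch.randn(bsz, seqlen, n_local_heads, head_dim)

# Optional additive mask, broadcastable to (bs, n_heads, seqlen, seqlen); None is also valid.
mask = None

# SDPA expects (bs, n_heads, seqlen, head_dim), hence the transpose of dims 1 and 2.
output = F.scaled_dot_product_attention(
    xq.transpose(1, 2),
    keys.transpose(1, 2),
    values.transpose(1, 2),
    attn_mask=mask,
    dropout_p=0.0,
    is_causal=False,
)

# Transpose back and flatten the heads to (bs, seqlen, n_heads * head_dim)
# before the final output projection.
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
print(output.shape)  # torch.Size([2, 16, 256])
```

Note that the change also passes the repeated `keys`/`values` (from `repeat_kv`) rather than the pre-repeat `xk`/`xv`, and supplies an explicit `attn_mask` with `is_causal=False` instead of relying on SDPA's built-in causal masking.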