jax.nn.dot_product_attention does not respect key_value_seq_lengths
#23349
Comments
@kaixih PTAL.
I just created a PR to fix this issue. Basically, the current API requires both query_seq_lengths and key_value_seq_lengths to be provided together. From the user side, you can also try explicitly providing the query_seq_lengths as a workaround.
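In other words, a call along these lines should sidestep the issue until the fix lands (a minimal sketch, not code from this thread; q, k, v, kv_lens, B, and L are placeholder names):

# Sketch: explicitly pass full-length query lengths alongside the kv lengths.
# q, k, v: [B, L, H, D] arrays; kv_lens: int32 [B] of valid key/value lengths.
out = nn.dot_product_attention(
    q, k, v,
    query_seq_lengths=jnp.full((B,), L, dtype=jnp.int32),  # every query is valid
    key_value_seq_lengths=kv_lens,
)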
The following works when I provide the max length for the query lengths:

#!/usr/bin/env python3
import jax.numpy as jnp
from jax import random, nn

B, L, H, D = 8, 128, 4, 64
rng = random.key(42)
x = random.normal(rng, (B, L, H, D // H), dtype=jnp.bfloat16)
valid_lens = jnp.array([24, 125, 53, 28, 77, 96, 13, 114], jnp.int32)

def vanilla_attention(qs, ks, vs, valid_lens):
    # Reference attention that masks out keys beyond each example's valid length.
    scores = jnp.einsum("BQHD,BKHD->BHQK", qs, ks) / jnp.sqrt(D // H)
    if valid_lens is not None:
        mask = jnp.arange(L) < valid_lens[:, None]
        mask = mask[:, None, None, :]  # broadcast across H, Q in [B, H, Q, K]
        scores = jnp.where(mask, scores, -jnp.inf)
    attn = nn.softmax(scores, axis=-1)
    return jnp.einsum("BHQK,BKHD->BQHD", attn, vs).reshape(B, L, D)

def xla_attention(qs, ks, vs, valid_lens):
    if valid_lens is None:
        valid_lens = jnp.repeat(L, B)
    ctx = nn.dot_product_attention(
        qs,
        ks,
        vs,
        query_seq_lengths=jnp.repeat(L, B),  # workaround: pass full query lengths explicitly
        key_value_seq_lengths=valid_lens,
        implementation="xla",
    )
    return ctx.reshape(B, L, D)

def cudnn_attention(qs, ks, vs, valid_lens):
    if valid_lens is None:
        valid_lens = jnp.repeat(L, B)
    ctx = nn.dot_product_attention(
        qs,
        ks,
        vs,
        query_seq_lengths=jnp.repeat(L, B),  # workaround: pass full query lengths explicitly
        key_value_seq_lengths=valid_lens,
        implementation="cudnn",
    )
    return ctx.reshape(B, L, D)

van_attn = vanilla_attention(x, x, x, valid_lens)
xla_attn = xla_attention(x, x, x, valid_lens)
cud_attn = cudnn_attention(x, x, x, valid_lens)
print(jnp.allclose(van_attn, xla_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(van_attn, cud_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(cud_attn, xla_attn, rtol=0.01, atol=0.01))  # True

van_attn = vanilla_attention(x, x, x, None)
xla_attn = xla_attention(x, x, x, None)
cud_attn = cudnn_attention(x, x, x, None)
print(jnp.allclose(van_attn, xla_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(van_attn, cud_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(cud_attn, xla_attn, rtol=0.01, atol=0.01))  # True
@danjenson Could you let us know whether it is a typical use case for you to only provide the kv_seq_lengths?
Constantly -- usually I want an answer to every "query" but each query can only use specific data/keys when answering that question.
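Concretely, the call pattern in question would pass only key_value_seq_lengths, since every query is valid but only a prefix of the keys/values is. A sketch, reusing x and valid_lens from the script above:

# All queries attend; only the first valid_lens[b] keys/values of batch element b
# should be visible. No query_seq_lengths is supplied.
ctx = nn.dot_product_attention(
    x, x, x,
    key_value_seq_lengths=valid_lens,
    implementation="xla",
)
# Per this report, the padded keys are not masked out in this case,
# which produces the data leak described in the issue.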
Description
Perhaps I am using this function incorrectly, but I get data leaks when using key_value_seq_lengths. It appears as though both the xla and cudnn implementations in jax nightly do not support this argument. Here is some reproducible code:

System info (python version, jaxlib version, accelerator, etc.)