
jax.nn.dot_product_attention does not respect key_value_seq_lengths #23349

Open
danjenson opened this issue Aug 30, 2024 · 5 comments
Labels
bug Something isn't working

Comments

@danjenson

danjenson commented Aug 30, 2024

Description

Perhaps I am using this function incorrectly, but I get data leaking across the masked key positions when using key_value_seq_lengths. It appears that neither the xla nor the cudnn implementation in jax nightly respects this argument. Here is a reproducible example:

#!/usr/bin/env python3
import jax.numpy as jnp
from jax import random, nn

B, L, H, D = 8, 128, 4, 64
rng = random.key(42)
x = random.normal(rng, (B, L, H, D // H), dtype=jnp.bfloat16)
valid_lens = jnp.array([24, 125, 53, 28, 77, 96, 13, 114], jnp.int32)  # valid key/value length per batch element


def vanilla_attention(qs, ks, vs, valid_lens):
    # Reference implementation: mask out key positions >= valid_lens[b] per batch element.
    scores = jnp.einsum("BQHD,BKHD->BHQK", qs, ks) / jnp.sqrt(D // H)
    if valid_lens is not None:
        mask = jnp.arange(L) < valid_lens[:, None]  # [B, K] validity mask over key positions
        mask = mask[:, None, None, :]  # broadcast across H, Q in [B, H, Q, K]
        scores = jnp.where(mask, scores, -jnp.inf)
    attn = nn.softmax(scores, axis=-1)
    return jnp.einsum("BHQK,BKHD->BQHD", attn, vs).reshape(B, L, D)


def xla_attention(qs, ks, vs, valid_lens):
    ctx = nn.dot_product_attention(
        qs, ks, vs, key_value_seq_lengths=valid_lens, implementation="xla"
    )
    return ctx.reshape(B, L, D)


def cudnn_attention(qs, ks, vs, valid_lens):
    ctx = nn.dot_product_attention(
        qs, ks, vs, key_value_seq_lengths=valid_lens, implementation="cudnn"
    )
    return ctx.reshape(B, L, D)


van_attn = vanilla_attention(x, x, x, valid_lens)
xla_attn = xla_attention(x, x, x, valid_lens)
cud_attn = cudnn_attention(x, x, x, valid_lens)
print(jnp.allclose(van_attn, xla_attn, rtol=1.0, atol=1.0))  # False
print(jnp.allclose(van_attn, cud_attn, rtol=1.0, atol=1.0))  # False
print(jnp.allclose(cud_attn, xla_attn, rtol=0.01, atol=0.01))  # True

van_attn = vanilla_attention(x, x, x, None)
xla_attn = xla_attention(x, x, x, None)
cud_attn = cudnn_attention(x, x, x, None)
print(jnp.allclose(van_attn, xla_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(van_attn, cud_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(cud_attn, xla_attn, rtol=0.01, atol=0.01))  # True

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.32.dev20240830
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.12.2 (main, Mar  2 2024, 09:51:01) [GCC 13.2.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='ghost', release='6.6.47_1', version='#1 SMP PREEMPT_DYNAMIC Mon Aug 19 16:42:31 UTC 2024', machine='x86_64')


$ nvidia-smi
Fri Aug 30 18:11:31 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.107.02             Driver Version: 550.107.02     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:01:00.0 Off |                  Off |
| 32%   39C    P2             78W /  480W |     393MiB /  24564MiB |      3%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A     28449      C   ...ions/3.12.2/envs/jax/bin/python3.12        386MiB |
+-----------------------------------------------------------------------------------------+
danjenson added the bug label Aug 30, 2024
@superbobry
Collaborator

@kaixih PTAL.

@kaixih
Contributor

kaixih commented Sep 3, 2024

I just created a PR to fix this issue. Basically, the current API requires both query_seq_lengths and key_value_seq_lengths to be provided together; the PR relaxes that. Can you take a look to see if it works for you?

On the user side, you can also work around this by explicitly providing query_seq_lengths as a tensor filled with the maximum sequence length.
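
For reference, a minimal sketch of that workaround using the shapes from the reproducer above (jnp.full here is just one way to build the filled tensor):

# Workaround sketch: pass query_seq_lengths filled with the full length L so
# that key_value_seq_lengths is actually applied by the xla/cudnn paths.
query_lens = jnp.full((B,), L, dtype=jnp.int32)  # every query position is valid
ctx = nn.dot_product_attention(
    x, x, x,
    query_seq_lengths=query_lens,
    key_value_seq_lengths=valid_lens,
    implementation="xla",  # or "cudnn"
)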

@danjenson
Author

The following works when I explicitly provide the max sequence length for query_seq_lengths:

#!/usr/bin/env python3
import jax.numpy as jnp
from jax import random, nn

B, L, H, D = 8, 128, 4, 64
rng = random.key(42)
x = random.normal(rng, (B, L, H, D // H), dtype=jnp.bfloat16)
valid_lens = jnp.array([24, 125, 53, 28, 77, 96, 13, 114], jnp.int32)


def vanilla_attention(qs, ks, vs, valid_lens):
    scores = jnp.einsum("BQHD,BKHD->BHQK", qs, ks) / jnp.sqrt(D // H)
    if valid_lens is not None:
        mask = jnp.arange(L) < valid_lens[:, None]
        mask = mask[:, None, None, :]  # broadcast across H, Q in [B, H, Q, K]
        scores = jnp.where(mask, scores, -jnp.inf)
    attn = nn.softmax(scores, axis=-1)
    return jnp.einsum("BHQK,BKHD->BQHD", attn, vs).reshape(B, L, D)


def xla_attention(qs, ks, vs, valid_lens):
    if valid_lens is None:
        valid_lens = jnp.repeat(L, B)
    ctx = nn.dot_product_attention(
        qs,
        ks,
        vs,
        query_seq_lengths=jnp.repeat(L, B),
        key_value_seq_lengths=valid_lens,
        implementation="xla",
    )
    return ctx.reshape(B, L, D)


def cudnn_attention(qs, ks, vs, valid_lens):
    if valid_lens is None:
        valid_lens = jnp.repeat(L, B)
    ctx = nn.dot_product_attention(
        qs,
        ks,
        vs,
        query_seq_lengths=jnp.repeat(L, B),
        key_value_seq_lengths=valid_lens,
        implementation="cudnn",
    )
    return ctx.reshape(B, L, D)


van_attn = vanilla_attention(x, x, x, valid_lens)
xla_attn = xla_attention(x, x, x, valid_lens)
cud_attn = cudnn_attention(x, x, x, valid_lens)
print(jnp.allclose(van_attn, xla_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(van_attn, cud_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(cud_attn, xla_attn, rtol=0.01, atol=0.01))  # True

van_attn = vanilla_attention(x, x, x, None)
xla_attn = xla_attention(x, x, x, None)
cud_attn = cudnn_attention(x, x, x, None)
print(jnp.allclose(van_attn, xla_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(van_attn, cud_attn, rtol=0.01, atol=0.01))  # True
print(jnp.allclose(cud_attn, xla_attn, rtol=0.01, atol=0.01))  # True

@kaixih
Contributor

kaixih commented Sep 4, 2024

@danjenson Can you let us know whether it is a typical use case for you to provide only the kv_seq_lengths?

@danjenson
Author

Constantly -- usually I want an answer for every query, but each query can only use specific data/keys when answering that question.
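
As a hypothetical illustration (the token ids and pad_id below are made up, not from this report): the per-example lengths typically come from sequences padded to a common length, and only the key/value side needs them because every query position should still produce an output.

# Hypothetical padded batch: pad_id marks trailing padding in each sequence.
pad_id = 0
tokens = jnp.array([[5, 3, 9, 0, 0],
                    [7, 2, 0, 0, 0]], jnp.int32)  # [B=2, L=5] token ids
kv_lens = jnp.sum(tokens != pad_id, axis=-1).astype(jnp.int32)  # [3, 2]
# With the relaxed API, only the key/value lengths would need to be passed, e.g.:
# out = nn.dot_product_attention(q, k, v, key_value_seq_lengths=kv_lens)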
