Skip to content

Commit

Permalink
Merge pull request #23415 from kaixih:key_value_seq_lengths
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 673409724
  • Loading branch information
Google-ML-Automation committed Sep 11, 2024
2 parents ea68f45 + 2d2cbbc commit e869a9d
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions jax/_src/nn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,10 +786,14 @@ def _get_causal_mask(T, S):
return mask[None, None, :, :]

def _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen):
q_indices = jnp.arange(0, T)[None, :, None]
kv_indices = jnp.arange(0, S)[None, None, :]
q_mask = q_indices < q_seqlen[:, None, None]
kv_mask = kv_indices < kv_seqlen[:, None, None]
q_mask = True
kv_mask = True
if q_seqlen is not None:
q_indices = jnp.arange(0, T)[None, :, None]
q_mask = q_indices < q_seqlen[:, None, None]
if kv_seqlen is not None:
kv_indices = jnp.arange(0, S)[None, None, :]
kv_mask = kv_indices < kv_seqlen[:, None, None]
mask = jnp.logical_and(q_mask, kv_mask)
return mask[:, None, :, :]

Expand All @@ -813,7 +817,7 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen):
mask = _get_causal_mask(T, S)
combined_mask = jnp.logical_and(combined_mask, mask)

if q_seqlen is not None and kv_seqlen is not None:
if q_seqlen is not None or kv_seqlen is not None:
mask = _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen)
combined_mask = jnp.logical_and(combined_mask, mask)

Expand Down Expand Up @@ -1001,12 +1005,22 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
kv_seqlen=key_value_seq_lengths,
)
case 'cudnn':
use_padding = (
query_seq_lengths is not None or key_value_seq_lengths is not None
)
if use_padding:
if query_seq_lengths is None:
T = query_arr.shape[1]
query_seq_lengths = jnp.full((B,), T, dtype=jnp.int32)
if key_value_seq_lengths is None:
key_value_seq_lengths = jnp.full((B,), S, dtype=jnp.int32)

mask_type = MaskType.NO_MASK
if query_seq_lengths is not None and is_causal:
if use_padding and is_causal:
mask_type = MaskType.PADDING_CAUSAL
elif is_causal:
mask_type = MaskType.CAUSAL
elif query_seq_lengths is not None:
elif use_padding:
mask_type = MaskType.PADDING
out = cudnn_dot_product_attention(
query_arr, key_arr, value_arr, bias, mask, query_seq_lengths,
Expand Down

0 comments on commit e869a9d

Please sign in to comment.