[NVIDIA] Relax the requirement for providing both query_seq_lengths and key_value_seq_lengths #23415

Merged: 1 commit, Sep 11, 2024

kaixih
Contributor

@kaixih kaixih commented Sep 3, 2024

Previously, nn.dot_product_attention required query_seq_lengths and key_value_seq_lengths to be provided together. This PR relaxes that requirement, so users can now provide only one of the two.

This is motivated by this issue.

cc. @superbobry

@kaixih
Contributor Author

kaixih commented Sep 3, 2024

Also, @sbodenstein for review.

Collaborator

@superbobry superbobry left a comment

Can you add a regression test for #23349, please?

kv_indices = jnp.arange(0, S)[None, None, :]
q_mask = q_indices < q_seqlen[:, None, None]
kv_mask = kv_indices < kv_seqlen[:, None, None]
q_mask = jnp.array(True, dtype=jnp.bool_)
Collaborator

Can you do jnp.bool_(True) or just True here?

Contributor Author

Changed.
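
The fallback discussed in this thread can be sketched as follows. This is a hedged illustration, not the code merged in the PR: padding_mask is a hypothetical helper name, and the point is only that a side whose lengths are omitted can use a plain True, which broadcasts against the mask built from the other side.

```python
import jax.numpy as jnp

def padding_mask(T, S, q_seqlen=None, kv_seqlen=None):
    """Hypothetical helper: boolean padding mask broadcastable to (B, T, S)."""
    if q_seqlen is None and kv_seqlen is None:
        return None  # nothing to mask
    # A side without lengths stays a plain True and broadcasts against the other.
    q_mask = True
    kv_mask = True
    if q_seqlen is not None:
        q_indices = jnp.arange(0, T)[None, :, None]      # (1, T, 1)
        q_mask = q_indices < q_seqlen[:, None, None]     # (B, T, 1)
    if kv_seqlen is not None:
        kv_indices = jnp.arange(0, S)[None, None, :]     # (1, 1, S)
        kv_mask = kv_indices < kv_seqlen[:, None, None]  # (B, 1, S)
    return jnp.logical_and(q_mask, kv_mask)

# Only key/value lengths given: the result has shape (B, 1, S).
kv_only = padding_mask(4, 4, kv_seqlen=jnp.array([2, 3]))
# Both lengths given: the result has shape (B, T, S).
both = padding_mask(4, 4, q_seqlen=jnp.array([1, 4]), kv_seqlen=jnp.array([2, 3]))
```

Using a Python True instead of jnp.array(True, dtype=jnp.bool_) avoids creating an extra array, since jnp.logical_and accepts Python scalars.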

tests/nn_test.py Outdated
@@ -122,7 +122,6 @@ def testDotProductAttentionMask(self, mask_mode):

is_causal = 'causal' in mask_mode
if 'padding' in mask_mode:
  q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32)
Collaborator

Is this change necessary?

Contributor Author

Done.

@sbodenstein
Contributor

LGTM

@google-ml-butler google-ml-butler bot added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels Sep 11, 2024
@copybara-service copybara-service bot merged commit e869a9d into jax-ml:main Sep 11, 2024
16 checks passed
4 participants