[NVIDIA] Extend the custom fp8 accumulate dtype in non-jit scenarios #3827

Merged: 2 commits, Apr 25, 2024

kaixih (Contributor) commented Apr 4, 2024

This PR extends the custom fp8 accumulate dtype to non-jit scenarios. The custom dtype changes gradient accumulation so that grads are combined with a max op rather than the default add op. It is designed for fp8 parameters whose grads are their new values, where we only want to pick the max among the grads of the sub-tensors.

Previously, this custom dtype had to be used inside a jit scope; this PR relaxes that restriction.
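To make the max-versus-add distinction above concrete, here is a minimal pure-Python sketch. It does not use JAX or the custom dtype from this PR. The `accumulate` helper and the sample values are hypothetical and only illustrate how the two accumulation rules differ when one parameter receives several gradient contributions.

```python
from functools import reduce


def accumulate(contributions, combine):
    """Fold per-use gradient contributions into one value with `combine`."""
    return reduce(combine, contributions)


# Hypothetical gradient contributions to a single fp8 parameter,
# one per sub-tensor that used it (values chosen for illustration only).
contributions = [0.5, 2.0, 1.25]

# Default behavior: contributions are added together.
added = accumulate(contributions, lambda a, b: a + b)

# Behavior described in this PR: keep only the largest contribution.
maxed = accumulate(contributions, max)

print(added, maxed)
```

With the add rule the result mixes all contributions, while the max rule returns one of the original values unchanged, which is what matters when a grad is meant to be the parameter's new value.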

Note: this PR depends on jax-ml/jax#20266.

cc. @nluehr @mingxu1067

kaixih force-pushed the fp8_meta_custom_dtype_non_jit branch 2 times, most recently from 4b8accf to f2070a1, on April 8, 2024 17:53
kaixih (Contributor, Author) commented Apr 8, 2024

The JAX PR has been merged, so I rebased my change to import the newly introduced earray. @mattjj, can you take a look?

kaixih force-pushed the fp8_meta_custom_dtype_non_jit branch from f2070a1 to 0b59501 on April 8, 2024 18:07
IvyZX (Collaborator) commented Apr 22, 2024

Hi, thanks for adding this change.
It looks like JAX hasn't released earray yet, so you would need to add a check when importing it if you want this PR to be merged early. You might also need to run on JAX nightly to use the JAX-side change.
Other than that, this PR looks good to me.
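One way to add such an import check is a guarded import with a feature flag. The sketch below is an assumption about the general pattern, not the code merged in this PR: the module path `jax._src.earray` and the flag name `CAN_USE_EARRAY` are illustrative guesses, and the snippet also runs when JAX is not installed at all.

```python
# Guarded import: fall back cleanly when the installed JAX has no earray.
# NOTE: the module path and the flag name are assumptions for illustration.
try:
    from jax._src import earray  # only present in sufficiently new JAX builds
    CAN_USE_EARRAY = True
except ImportError:  # also covers ModuleNotFoundError
    earray = None
    CAN_USE_EARRAY = False

if CAN_USE_EARRAY:
    print("earray available: the non-jit path can be enabled")
else:
    print("earray missing: keep the jit-only behavior")
```

Code that depends on earray can then branch on the flag instead of failing at import time on released JAX versions.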

kaixih (Contributor, Author) commented Apr 23, 2024

@IvyZX Thanks for the review. I made one more change to make sure it is compatible with the current JAX version, which has no earray. PTAL.

copybara-service bot merged commit 7818932 into google:main on Apr 25, 2024
19 checks passed