-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Replacing deprecated XLA translation rules with MLIR lowering rules #21
Conversation
@lgarrison — I don't have a mechanism for testing the GPU ops. Would you be able to run a quick check? |
No problem! Unfortunately it looks like something did break on the GPU side. Here's the first failing test:
None of the GPU tests pass; all of the CPU ones do. Something nasty must be happening because the GPU tests eventually crash with what I think is a bad address on the GPU side. |
OK - thanks! I will look into what's happening here. |
@@ -194,7 +194,7 @@ def batch(args, axes, *, output_shape, **kwargs): | |||
nufft2_p.def_abstract_eval(shapes.abstract_eval) | |||
mlir.register_lowering(nufft2_p, partial(lowering.lowering, "cpu"), platform="cpu") | |||
if lowering.jax_finufft_gpu is not None: | |||
mlir.register_lowering(nufft2_p, partial(lowering.lowering, "cpu"), platform="gpu") | |||
mlir.register_lowering(nufft2_p, partial(lowering.lowering, "gpu"), platform="gpu") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@lgarrison — Doh! I think I found the issue lol
Looks like that was it, the GPU tests pass now! |
The old interface used to define XLA translation rules have been deprecated in favor of these MLIR lowering rules. This PR replaces the old implementation.
Thepre-commit
rules were also slightly updated and run on the full code base, so this is a slightly larger diff than I was planning.