
Adding advanced interface #68

Merged: dfm merged 16 commits into main from advanced on Mar 15, 2024

Conversation

@dfm (Collaborator) commented Mar 13, 2024

This is still very much a work in progress, but it should help with #66 and #67. I'll request feedback ASAP!

@dfm dfm marked this pull request as ready for review March 14, 2024 17:58
@dfm dfm requested a review from lgarrison March 14, 2024 17:58
@dfm (Collaborator, Author) commented Mar 14, 2024

@lgarrison — Here's my first pass at all this. Here's how we would provide the configuration relevant for #67:

from jax_finufft import nufft2, options

opts = options.NestedOpts(
  type1=options.Opts(gpu_method=1),
  type2=options.Opts(gpu_method=2),
)

nufft2(..., opts=opts)

or, equivalently in this case (for nufft2 the forward transform is type 2 and the backward pass uses a type 1 transform):

from jax_finufft import nufft2, options

opts = options.NestedOpts(
  forward=options.Opts(gpu_method=2),
  backward=options.Opts(gpu_method=1),
)

nufft2(..., opts=opts)

It's not all that ergonomic, but I think it's a decent start!

@dfm (Collaborator, Author) commented Mar 14, 2024

I've also managed to break the CUDA compilation with something I did to the includes. I think it has something to do with jax_finufft_gpu.h being included twice; we probably need to move the descriptor definition to a separate header.
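If it is the double include, it's presumably the classic multiple-definition failure. Roughly (with stand-in names, not the actual code):

// jax_finufft_gpu.h (sketch): a full function *definition* living in a header.
int default_opts_impl() { return 0; }

// a.cc and b.cc both contain:
#include "jax_finufft_gpu.h"

// Each object file now carries its own copy of default_opts_impl, so the
// link step fails with "multiple definition of default_opts_impl". Include
// guards don't help here: they only prevent double inclusion within a
// single translation unit, not definitions repeated across two of them.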

@lgarrison (Member) commented Mar 15, 2024

The immediate CUDA compilation error is just that there's no declaration of the default_opts<T> function visible to jax_finufft_gpu.cc. That declaration lives in lib/jax_finufft_gpu.h, but that file can't be included as a header in multiple compilation units because it contains function definitions as well as declarations. To fix this, I did the usual thing of splitting the declarations out into a header and putting the definitions in a source file. I called them cufinufft_wrapper.h and cufinufft_wrapper.cc, since most of that file is about giving the cufinufft functions C++ wrappers. But if we don't want to fix it this way for any reason, let me know!
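In sketch form, the split is just the standard declaration/definition separation (simplified stand-in types and signatures, not the real ones):

// cufinufft_wrapper.h (sketch): declarations and type definitions only,
// safe to include from any number of translation units.
#pragma once

struct Opts {
  int gpu_method = 0;  // stand-in for the real cufinufft options struct
};

// Declared here; the single definition lives in cufinufft_wrapper.cc.
template <typename T>
Opts default_opts();

// cufinufft_wrapper.cc (sketch): the one translation unit holding the
// definitions.
#include "cufinufft_wrapper.h"

template <typename T>
Opts default_opts() {
  return Opts{};
}

// Explicit instantiations, so that default_opts<float> and
// default_opts<double> are emitted for other translation units to
// link against.
template Opts default_opts<float>();
template Opts default_opts<double>();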

Some CUDA tests fail locally with a CUDA illegal memory access. Not yet sure whether it's a problem with the opts or with this header refactoring.

@lgarrison (Member) commented:

The problem was indeed with the header refactoring. The y_index and z_index functions have generic templated definitions in the header, as well as template specializations in the source file. But if the specializations aren't declared in the header, then the compiler won't know it needs to look for the specializations and will just use the generic version.

I don't like my solution; it feels fragile to me! It's too easy to write a specialization in the source file that gets silently ignored. Not sure if there's a better pattern we should be using here.
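Distilled down, the pattern now in place looks like this (stand-in names again, not the actual code):

// cufinufft_wrapper.h (sketch)
#pragma once

// Generic fallback: fine to define in the header, since function templates
// can be instantiated in multiple translation units.
template <typename T>
int y_index(int i) { return i; }

// The fragile-but-necessary line: declare the specialization in the header.
// Without this declaration, a translation unit that only sees the generic
// template silently instantiates it, ignoring the specialization below
// (formally that's an ODR violation, with no diagnostic required).
template <>
int y_index<double>(int i);

// cufinufft_wrapper.cc (sketch)
#include "cufinufft_wrapper.h"

template <>
int y_index<double>(int i) { return 2 * i; }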

@dfm (Collaborator, Author) commented Mar 15, 2024

Thanks for tracking this down @lgarrison!! I'll take a look this afternoon.

@lgarrison (Member) commented:

I can confirm that the opts are working for me and that they fix the performance issue from #67.

@dfm (Collaborator, Author) commented Mar 15, 2024

Thanks @lgarrison! I think that the approach you came up with here is totally fine. I agree that it's not very elegant, but I think we should just roll with it and revisit only if we need to later. It's possible that the whole library could benefit from some refactoring, but let's not let that get in the way of merging this. With that in mind, I'm going to merge this now!

This fixes #67, but let's leave #66 open until we add info to the README.

@dfm dfm linked an issue Mar 15, 2024 that may be closed by this pull request
@dfm dfm merged commit ef69daa into main Mar 15, 2024
5 checks passed
@dfm dfm deleted the advanced branch March 15, 2024 19:19