Update vendored finufft and add GPU support (#20)
* starting to add optional cuda support
* include dirs for cuda
* getting cufinufft to compile
* adding first pass at gpu kernels
* order of parameters
* Minor refactoring to support GPU
* Maybe sort-of calling all the right functions?
* Add FindCUDAToolkit to cmake to bring in cufft
* Trying to hook up Jax CUDA ops
* Don't fail on no CUDA
* first pass at getting GPU ops to work
* Fix GPU tests
* vendor: update vendored finufft version to latest and fix deprecations
* gpu: use new cufinufft API and change CMake to reflect the fact that the single and double precision interfaces are compiled together now
* xla: uppercase CUDA doesn't work anymore, use cuda. GPU tests now run but segfault.
* gpu: fix extraneous translation_rule arg
* gpu: custom call target registration uses capital CUDA, while translation rules use lowercase cuda, weirdly
* gpu: use x64 for some tests that were off by 1.1e-7
* gpu: skip some 1D tests
* cmake: get colored output through ninja
* gpu: use the CUDA stream provided by JAX
* vendor: use lgarrison fork of finufft until flatironinstitute/finufft#330 and flatironinstitute/finufft#354 are merged
* Fixes for modern JAX: block until CUDA operations complete. Import jax.experimental. Point to vendored finufft with more fixes.
* Probably don't need to sync the stream, JAX ought to do that. But we do need to sync before synchronously destroying resources.
* vendor: update finufft

---------

Co-authored-by: Dan F-M <foreman.mackey@gmail.com>
Co-authored-by: Dan Foreman-Mackey <danfm@nyu.edu>
1 parent ffb336d commit b2b2cd0
Showing 17 changed files with 598 additions and 122 deletions.
@@ -5,3 +5,4 @@ _skbuild
 dist
 MANIFEST
 __pycache__/
+*.egg-info
@@ -1,3 +1,3 @@
 [submodule "finufft"]
 	path = vendor/finufft
-	url = https://github.com/flatironinstitute/finufft
+	url = https://github.com/lgarrison/finufft
File renamed without changes.
@@ -0,0 +1,21 @@
+#ifndef _JAX_FINUFFT_COMMON_H_
+#define _JAX_FINUFFT_COMMON_H_
+
+// This descriptor is common to both the jax_finufft and jax_finufft_gpu modules
+// We will use the jax_finufft namespace for both
+
+namespace jax_finufft {
+
+template <typename T>
+struct NufftDescriptor {
+  T eps;
+  int iflag;
+  int64_t n_tot;
+  int n_transf;
+  int64_t n_j;
+  int64_t n_k[3];
+};
+
+}
+
+#endif
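The descriptor above is a plain struct of scalars, so one way it might cross the Python/C++ boundary is as an opaque byte string. The sketch below assumes that approach. `pack_descriptor` and `unpack_descriptor` are hypothetical helper names that do not appear in this diff, and the struct is mirrored here only so the block compiles on its own.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>
#include <type_traits>

namespace jax_finufft {

// Mirrors the struct from the diff so this sketch is self-contained.
template <typename T>
struct NufftDescriptor {
  T eps;
  int iflag;
  std::int64_t n_tot;
  int n_transf;
  std::int64_t n_j;
  std::int64_t n_k[3];
};

// Hypothetical helper: copy the descriptor's raw bytes into a string.
template <typename T>
std::string pack_descriptor(const NufftDescriptor<T>& d) {
  static_assert(std::is_trivially_copyable<NufftDescriptor<T>>::value,
                "descriptor must be trivially copyable to be sent as bytes");
  std::string bytes(sizeof(d), '\0');
  std::memcpy(&bytes[0], &d, sizeof(d));
  return bytes;
}

// Hypothetical helper: rebuild a descriptor from an opaque buffer,
// rejecting buffers whose length does not match the struct size.
template <typename T>
NufftDescriptor<T> unpack_descriptor(const char* opaque, std::size_t opaque_len) {
  assert(opaque != nullptr);
  if (opaque_len != sizeof(NufftDescriptor<T>)) {
    throw std::invalid_argument("opaque buffer has the wrong size");
  }
  NufftDescriptor<T> d;
  std::memcpy(&d, opaque, sizeof(d));
  return d;
}

}  // namespace jax_finufft
```

Because the bytes are copied with `std::memcpy`, the round trip only makes sense when both sides are built with the same struct layout. The size check catches a mismatch between the single and double precision descriptors, but not every layout difference.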
@@ -0,0 +1,37 @@
+// This file defines the Python interface to the XLA custom call implemented on the CPU.
+// It is exposed as a standard pybind11 module defining "capsule" objects containing our
+// method. For simplicity, we export a separate capsule for each supported dtype.
+
+#include "pybind11_kernel_helpers.h"
+#include "kernels.h"
+
+using namespace jax_finufft;
+
+namespace {
+
+pybind11::dict Registrations() {
+  pybind11::dict dict;
+
+  // TODO: do we prefer to keep these names the same as the CPU version or prefix them with "cu"?
+  // dict["nufft1d1f"] = encapsulate_function(nufft1d1f);
+  // dict["nufft1d2f"] = encapsulate_function(nufft1d2f);
+  dict["nufft2d1f"] = encapsulate_function(nufft2d1f);
+  dict["nufft2d2f"] = encapsulate_function(nufft2d2f);
+  dict["nufft3d1f"] = encapsulate_function(nufft3d1f);
+  dict["nufft3d2f"] = encapsulate_function(nufft3d2f);
+
+  // dict["nufft1d1"] = encapsulate_function(nufft1d1);
+  // dict["nufft1d2"] = encapsulate_function(nufft1d2);
+  dict["nufft2d1"] = encapsulate_function(nufft2d1);
+  dict["nufft2d2"] = encapsulate_function(nufft2d2);
+  dict["nufft3d1"] = encapsulate_function(nufft3d1);
+  dict["nufft3d2"] = encapsulate_function(nufft3d2);
+
+  return dict;
+}
+
+PYBIND11_MODULE(jax_finufft_gpu, m) {
+  m.def("registrations", &Registrations);
+}
+
+}  // namespace
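The registration table above exports one entry per dimension, transform type, and precision, with an `f` suffix marking the single precision variants. A minimal sketch of that naming pattern, using only the standard library, might look like the following. It swaps the pybind11 capsules for a plain `std::map` of function pointers, and `KernelFn`, `nufft_stub`, and `make_registrations` are hypothetical names for illustration, not part of jax_finufft. As in the diff, the 1D entries are left out.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>

// Hypothetical, simplified kernel signature used only for this sketch.
using KernelFn = std::size_t (*)();

// Stub standing in for a kernel templated on dimension, transform type,
// and precision. It returns sizeof(T) so a caller can observe which
// precision a given table entry refers to.
template <int ndim, int type, typename T>
std::size_t nufft_stub() {
  return sizeof(T);
}

// Build a name -> function table that follows the naming scheme in the diff:
// "nufft<ndim>d<type>" for double precision, with a trailing "f" for float.
std::map<std::string, KernelFn> make_registrations() {
  std::map<std::string, KernelFn> table;

  // Single precision entries.
  table["nufft2d1f"] = nufft_stub<2, 1, float>;
  table["nufft2d2f"] = nufft_stub<2, 2, float>;
  table["nufft3d1f"] = nufft_stub<3, 1, float>;
  table["nufft3d2f"] = nufft_stub<3, 2, float>;

  // Double precision entries.
  table["nufft2d1"] = nufft_stub<2, 1, double>;
  table["nufft2d2"] = nufft_stub<2, 2, double>;
  table["nufft3d1"] = nufft_stub<3, 1, double>;
  table["nufft3d2"] = nufft_stub<3, 2, double>;

  // Sanity check: 2 dimensions x 2 transform types x 2 precisions.
  assert(table.size() == 8);
  return table;
}
```

Keeping the names identical to the CPU module, as the TODO in the diff considers, would let calling code look up a kernel by the same key regardless of which module supplied the table.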