Skip to content

Commit

Permalink
Update vendored finufft and add GPU support (#20)
Browse files Browse the repository at this point in the history
* starting to add optional cuda support

* include dirs for cuda

* getting cufinufft to compile

* adding first pass at gpu kernels

* order of parameters

* Minor refactoring to support GPU

* Maybe sort-of calling all the right functions?

* Add FindCUDAToolkit to cmake to bring in cufft

* Trying to hook up Jax CUDA ops

* Don't fail on no CUDA

* first pass at getting GPU ops to work

* Fix GPU tests

* vendor: update vendored finufft version to latest and fix deprecations

* gpu: use new cufinufft API and change CMake to reflect the fact that the single and double precision interfaces are compiled together now

* xla: uppercase CUDA doesn't work anymore, use cuda. GPU tests now run but segfault.

* gpu: fix extraneous translation_rule arg

* gpu: custom call target registration uses capital CUDA, while translation rules use lowercase cuda, weirdly

* gpu: use x64 for some tests that were off by 1.1e-7

* gpu: skip some 1D tests

* cmake: get colored output through ninja

* gpu: use the CUDA stream provided by JAX

* vendor: use lgarrison fork of finufft until flatironinstitute/finufft#330 and flatironinstitute/finufft#354 are merged

* Fixes for modern JAX: block until CUDA operations complete. Import jax.experimental. Point to vendored finufft with more fixes.

* Probably don't need to sync the stream, JAX ought to do that. But we do need to sync before synchronously destroying resources.

* vendor: update finufft

---------

Co-authored-by: Dan F-M <foreman.mackey@gmail.com>
Co-authored-by: Dan Foreman-Mackey <danfm@nyu.edu>
  • Loading branch information
3 people authored Oct 30, 2023
1 parent ffb336d commit b2b2cd0
Show file tree
Hide file tree
Showing 17 changed files with 598 additions and 122 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ _skbuild
dist
MANIFEST
__pycache__/
*.egg-info
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[submodule "finufft"]
path = vendor/finufft
url = https://github.com/flatironinstitute/finufft
url = https://github.com/lgarrison/finufft
92 changes: 81 additions & 11 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ cmake_minimum_required(VERSION 3.15)
project(${SKBUILD_PROJECT_NAME} LANGUAGES C CXX)
message(STATUS "Using CMake version: " ${CMAKE_VERSION})

set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_LIST_DIR})
# Add the /cmake directory to the module path so that we can find FFTW
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_LIST_DIR}/cmake)

# Handle Python settings passed from scikit-build
if(SKBUILD)
set(Python_EXECUTABLE "${PYTHON_EXECUTABLE}")
set(Python_INCLUDE_DIR "${PYTHON_INCLUDE_DIR}")
Expand All @@ -13,23 +15,29 @@ if(SKBUILD)
OUTPUT_VARIABLE _tmp_dir
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ECHO STDOUT)
list(APPEND CMAKE_PREFIX_PATH "${_tmp_dir}")
else()
find_package(Python COMPONENTS Interpreter Development REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -c "import pybind11; print(pybind11.get_cmake_dir())"
OUTPUT_VARIABLE _tmp_dir
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ECHO STDOUT)
list(APPEND CMAKE_PREFIX_PATH "${_tmp_dir}")
endif()

set(PYBIND11_NEWPYTHON ON)
find_package(pybind11 CONFIG REQUIRED)
find_package(FFTW REQUIRED COMPONENTS FLOAT_LIB DOUBLE_LIB)
link_libraries(${FFTW_FLOAT_LIB} ${FFTW_DOUBLE_LIB})

# Work out compiler flags
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
add_compile_options(-Wall -O3 -funroll-loops)

add_compile_options(-Wall -Wno-unknown-pragmas -O3 -funroll-loops -fdiagnostics-color)
set(FINUFFT_INCLUDE_DIRS
${CMAKE_CURRENT_LIST_DIR}/lib
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/include
${FFTW_INCLUDE_DIRS})

message(STATUS "FINUFFT include dirs: " "${FINUFFT_INCLUDE_DIRS}")

# Build single and double point versions of the FINUFFT library
add_library(finufft STATIC
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/spreadinterp.cpp
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/utils.cpp
Expand All @@ -45,10 +53,72 @@ add_library(finufft_32 STATIC
target_compile_definitions(finufft_32 PUBLIC SINGLE)
target_include_directories(finufft_32 PRIVATE ${FINUFFT_INCLUDE_DIRS})

pybind11_add_module(jax_finufft
${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft.cc
# Build the XLA bindings to those libraries
pybind11_add_module(jax_finufft_cpu
${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_cpu.cc
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/utils_precindep.cpp
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/contrib/legendre_rule_fast.c)
target_link_libraries(jax_finufft PRIVATE finufft finufft_32)
target_include_directories(jax_finufft PRIVATE ${FINUFFT_INCLUDE_DIRS})
install(TARGETS jax_finufft DESTINATION .)
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/contrib/legendre_rule_fast.cpp)

target_link_libraries(jax_finufft_cpu PRIVATE finufft finufft_32)
target_include_directories(jax_finufft_cpu PRIVATE ${FINUFFT_INCLUDE_DIRS})
install(TARGETS jax_finufft_cpu DESTINATION .)

# Optionally build the GPU (CUDA) extension. GPU support is skipped with a
# status message when no CUDA compiler is found, rather than failing the build.
include(CheckLanguage)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
  enable_language(CUDA)
  set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
  # Default to a broad range of real architectures unless the user picked one.
  if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
    set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80;90")
  endif()

  # CUDAToolkit supplies the CUDA::cufft and CUDA::nvToolsExt imported targets.
  # REQUIRED: a CUDA compiler was found above, so a missing toolkit is an error
  # rather than something to silently ignore.
  find_package(CUDAToolkit REQUIRED)

  set(CUFINUFFT_INCLUDE_DIRS
      ${CMAKE_CURRENT_LIST_DIR}/lib
      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/include/cufinufft/contrib/cuda_samples
      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/include
      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/contrib
      ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})

  set(CUFINUFFT_SOURCES
      # TODO: 1D not supported via JAX, but needed for compilation
      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/1d/spread1d_wrapper.cu
      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/1d/interp1d_wrapper.cu
      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/1d/cufinufft1d.cu

      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/2d/spread2d_wrapper.cu
      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/2d/interp2d_wrapper.cu
      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/2d/cufinufft2d.cu

      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/3d/spread3d_wrapper.cu
      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/3d/interp3d_wrapper.cu
      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/3d/cufinufft3d.cu

      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/deconvolve_wrapper.cu
      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/memtransfer_wrapper.cu

      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/spreadinterp.cpp

      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/precision_independent.cu
      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/contrib/legendre_rule_fast.cpp
      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/utils.cpp
      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/cufinufft.cu
      ${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/common.cu)

  add_library(cufinufft STATIC ${CUFINUFFT_SOURCES})
  target_include_directories(cufinufft PRIVATE ${CUFINUFFT_INCLUDE_DIRS})

  # XLA bindings to the CUDA kernels.
  pybind11_add_module(jax_finufft_gpu
      ${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_gpu.cc
      ${CMAKE_CURRENT_LIST_DIR}/lib/kernels.cc.cu)
  # Link the imported targets from FindCUDAToolkit. The legacy variables
  # CUDA_cufft_LIBRARY / CUDA_nvToolsExt_LIBRARY belong to the deprecated
  # FindCUDA module and are NOT set by find_package(CUDAToolkit), so linking
  # them would silently expand to nothing.
  target_link_libraries(jax_finufft_gpu PRIVATE cufinufft CUDA::cufft CUDA::nvToolsExt)
  target_include_directories(jax_finufft_gpu PRIVATE ${CUFINUFFT_INCLUDE_DIRS})
  install(TARGETS jax_finufft_gpu DESTINATION .)

else()
  message(STATUS "No CUDA compiler found; GPU support will be disabled")
endif()
File renamed without changes.
21 changes: 21 additions & 0 deletions lib/jax_finufft_common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Header guard renamed: identifiers starting with an underscore followed by an
// uppercase letter (_JAX_...) are reserved to the implementation in C++.
#ifndef JAX_FINUFFT_COMMON_H_
#define JAX_FINUFFT_COMMON_H_

#include <cstdint>  // int64_t (was previously relied on transitively)

// This descriptor is common to both the CPU and GPU pybind11 modules.
// We use the jax_finufft namespace for both.

namespace jax_finufft {

// Plain-old-data problem descriptor built on the Python side and passed
// opaquely through XLA to the custom-call kernels.
template <typename T>
struct NufftDescriptor {
  T eps;           // requested NUFFT tolerance
  int iflag;       // sign of the imaginary unit in the exponent (+1 or -1)
  int64_t n_tot;   // NOTE(review): presumably the total batch size — confirm against callers
  int n_transf;    // number of transforms per plan
  int64_t n_j;     // number of nonuniform points
  int64_t n_k[3];  // mode counts per dimension (up to 3 dimensions)
};

}  // namespace jax_finufft

#endif  // JAX_FINUFFT_COMMON_H_
6 changes: 4 additions & 2 deletions lib/jax_finufft.cc → lib/jax_finufft_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include "pybind11_kernel_helpers.h"

#include "jax_finufft_cpu.h"

using namespace jax_finufft;

namespace {
Expand All @@ -15,7 +17,7 @@ void run_nufft(int type, void *desc_in, T *x, T *y, T *z, std::complex<T> *c, st
int64_t n_k = 1;
for (int d = 0; d < ndim; ++d) n_k *= descriptor->n_k[d];

nufft_opts *opts = new nufft_opts;
finufft_opts *opts = new finufft_opts;
default_opts<T>(opts);

typename plan_type<T>::type plan;
Expand Down Expand Up @@ -86,7 +88,7 @@ pybind11::dict Registrations() {
return dict;
}

PYBIND11_MODULE(jax_finufft, m) {
PYBIND11_MODULE(jax_finufft_cpu, m) {
m.def("registrations", &Registrations);
m.def("build_descriptorf", &build_descriptor<float>);
m.def("build_descriptor", &build_descriptor<double>);
Expand Down
22 changes: 6 additions & 16 deletions lib/jax_finufft.h → lib/jax_finufft_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,6 @@

namespace jax_finufft {

template <typename T>
struct NufftDescriptor {
T eps;
int iflag;
int64_t n_tot;
int n_transf;
int64_t n_j;
int64_t n_k[3];
};

template <typename T>
struct plan_type;

Expand All @@ -31,31 +21,31 @@ struct plan_type<float> {
};

template <typename T>
void default_opts(nufft_opts* opts);
void default_opts(finufft_opts* opts);

template <>
void default_opts<float>(nufft_opts* opts) {
void default_opts<float>(finufft_opts* opts) {
finufftf_default_opts(opts);
}

template <>
void default_opts<double>(nufft_opts* opts) {
void default_opts<double>(finufft_opts* opts) {
finufft_default_opts(opts);
}

template <typename T>
int makeplan(int type, int dim, int64_t* nmodes, int iflag, int ntr, T eps,
typename plan_type<T>::type* plan, nufft_opts* opts);
typename plan_type<T>::type* plan, finufft_opts* opts);

template <>
int makeplan<float>(int type, int dim, int64_t* nmodes, int iflag, int ntr, float eps,
typename plan_type<float>::type* plan, nufft_opts* opts) {
typename plan_type<float>::type* plan, finufft_opts* opts) {
return finufftf_makeplan(type, dim, nmodes, iflag, ntr, eps, plan, opts);
}

template <>
int makeplan<double>(int type, int dim, int64_t* nmodes, int iflag, int ntr, double eps,
typename plan_type<double>::type* plan, nufft_opts* opts) {
typename plan_type<double>::type* plan, finufft_opts* opts) {
return finufft_makeplan(type, dim, nmodes, iflag, ntr, eps, plan, opts);
}

Expand Down
37 changes: 37 additions & 0 deletions lib/jax_finufft_gpu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// This file defines the Python interface to the XLA custom calls implemented on the
// GPU (CUDA). It is exposed as a standard pybind11 module defining "capsule" objects
// containing our methods. For simplicity, we export a separate capsule for each
// supported dtype.

#include "pybind11_kernel_helpers.h"
#include "kernels.h"

using namespace jax_finufft;

namespace {

// Builds the dict mapping XLA custom-call target names to function capsules.
// The Python side is expected to register each entry with XLA for the CUDA
// platform; only the 2D and 3D kernels are exposed here.
pybind11::dict Registrations() {
  pybind11::dict dict;

  // TODO: do we prefer to keep these names the same as the CPU version or prefix them with "cu"?
  // 1D entries are intentionally left unregistered (1D is compiled for linking
  // but not wired up through JAX — see the commented sources note in CMake).
  // dict["nufft1d1f"] = encapsulate_function(nufft1d1f);
  // dict["nufft1d2f"] = encapsulate_function(nufft1d2f);
  // Single-precision kernels (names carry the "f" suffix).
  dict["nufft2d1f"] = encapsulate_function(nufft2d1f);
  dict["nufft2d2f"] = encapsulate_function(nufft2d2f);
  dict["nufft3d1f"] = encapsulate_function(nufft3d1f);
  dict["nufft3d2f"] = encapsulate_function(nufft3d2f);

  // Double-precision kernels.
  // dict["nufft1d1"] = encapsulate_function(nufft1d1);
  // dict["nufft1d2"] = encapsulate_function(nufft1d2);
  dict["nufft2d1"] = encapsulate_function(nufft2d1);
  dict["nufft2d2"] = encapsulate_function(nufft2d2);
  dict["nufft3d1"] = encapsulate_function(nufft3d1);
  dict["nufft3d2"] = encapsulate_function(nufft3d2);

  return dict;
}

PYBIND11_MODULE(jax_finufft_gpu, m) {
  m.def("registrations", &Registrations);
}

}  // namespace
Loading

0 comments on commit b2b2cd0

Please sign in to comment.