Updating pre-commit rules (#23)
dfm authored Oct 30, 2023
1 parent b2b2cd0 commit b9296a9
Showing 12 changed files with 85 additions and 76 deletions.
26 changes: 13 additions & 13 deletions .pre-commit-config.yaml
@@ -1,20 +1,20 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.4.0
rev: "v4.4.0"
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
exclude_types: [json]
- id: debug-statements

- repo: https://github.com/PyCQA/isort
rev: "5.8.0"
hooks:
- id: isort
args: []
additional_dependencies: [toml]

exclude_types: [json, binary]
- repo: https://github.com/psf/black
rev: "20.8b1"
rev: "23.7.0"
hooks:
- id: black-jupyter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.0.285"
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: "v17.0.3"
hooks:
- id: black
- id: clang-format
8 changes: 4 additions & 4 deletions CMakeLists.txt
@@ -71,7 +71,7 @@ if (CMAKE_CUDA_COMPILER)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80;90")
endif()

# Find cufft
find_package(CUDAToolkit)

@@ -99,9 +99,9 @@ if (CMAKE_CUDA_COMPILER)

${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/deconvolve_wrapper.cu
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/memtransfer_wrapper.cu

${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/spreadinterp.cpp

${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/precision_independent.cu
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/contrib/legendre_rule_fast.cpp
${CMAKE_CURRENT_LIST_DIR}/vendor/finufft/src/cuda/utils.cpp
@@ -110,7 +110,7 @@ if (CMAKE_CUDA_COMPILER)

add_library(cufinufft STATIC ${CUFINUFFT_SOURCES})
target_include_directories(cufinufft PRIVATE ${CUFINUFFT_INCLUDE_DIRS})

pybind11_add_module(jax_finufft_gpu
${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_gpu.cc
${CMAKE_CURRENT_LIST_DIR}/lib/kernels.cc.cu)
4 changes: 2 additions & 2 deletions lib/jax_finufft_common.h
@@ -15,7 +15,7 @@ struct NufftDescriptor {
int64_t n_j;
int64_t n_k[3];
};
}

} // namespace jax_finufft

#endif
4 changes: 2 additions & 2 deletions lib/jax_finufft_cpu.cc
@@ -2,10 +2,10 @@
// It is exposed as a standard pybind11 module defining "capsule" objects containing our
// method. For simplicity, we export a separate capsule for each supported dtype.

#include "pybind11_kernel_helpers.h"

#include "jax_finufft_cpu.h"

#include "pybind11_kernel_helpers.h"

using namespace jax_finufft;

namespace {
6 changes: 2 additions & 4 deletions lib/jax_finufft_gpu.cc
@@ -2,8 +2,8 @@
// It is exposed as a standard pybind11 module defining "capsule" objects containing our
// method. For simplicity, we export a separate capsule for each supported dtype.

#include "pybind11_kernel_helpers.h"
#include "kernels.h"
#include "pybind11_kernel_helpers.h"

using namespace jax_finufft;

@@ -30,8 +30,6 @@ pybind11::dict Registrations() {
return dict;
}

PYBIND11_MODULE(jax_finufft_gpu, m) {
m.def("registrations", &Registrations);
}
PYBIND11_MODULE(jax_finufft_gpu, m) { m.def("registrations", &Registrations); }

} // namespace
9 changes: 5 additions & 4 deletions lib/jax_finufft_gpu.h
@@ -33,12 +33,12 @@ template <>
void default_opts<double>(int type, int dim, cufinufft_opts* opts, cudaStream_t stream) {
cufinufft_default_opts(opts);
opts->gpu_stream = stream;

// double precision in 3D blows out shared memory.
// Fall back to a slower, non-shared memory algorithm
// https://github.com/flatironinstitute/cufinufft/issues/58
if(dim > 2){
opts->gpu_method = 1;
if (dim > 2) {
opts->gpu_method = 1;
}
}

@@ -49,7 +49,8 @@ int makeplan(int type, int dim, const int64_t nmodes[3], int iflag, int ntr, T e
template <>
int makeplan<float>(int type, int dim, const int64_t nmodes[3], int iflag, int ntr, float eps,
typename plan_type<float>::type* plan, cufinufft_opts* opts) {
int64_t tmp_nmodes[3] = {nmodes[0], nmodes[1], nmodes[2]}; // TODO: use const in cufinufftf_makeplan API
int64_t tmp_nmodes[3] = {nmodes[0], nmodes[1],
nmodes[2]}; // TODO: use const in cufinufftf_makeplan API
return cufinufftf_makeplan(type, dim, tmp_nmodes, iflag, ntr, eps, plan, opts);
}

43 changes: 21 additions & 22 deletions lib/kernels.cc.cu
@@ -1,6 +1,6 @@
#include "jax_finufft_gpu.h"
#include "kernels.h"
#include "kernel_helpers.h"
#include "kernels.h"

namespace jax_finufft {

@@ -9,18 +9,18 @@ void ThrowIfError(cudaError_t error) {
throw std::runtime_error(cudaGetErrorString(error));
}
}

template <int ndim, typename T>
void run_nufft(int type, const NufftDescriptor<T>* descriptor, T *x, T *y, T *z,
std::complex<T> *c, std::complex<T> *F, cudaStream_t stream) {
void run_nufft(int type, const NufftDescriptor<T> *descriptor, T *x, T *y, T *z,
std::complex<T> *c, std::complex<T> *F, cudaStream_t stream) {
int64_t n_k = 1;
for (int d = 0; d < ndim; ++d) n_k *= descriptor->n_k[d];

cufinufft_opts *opts = new cufinufft_opts;
typename plan_type<T>::type plan;
default_opts<T>(type, ndim, opts, stream);
makeplan<T>(type, ndim, descriptor->n_k, descriptor->iflag,
descriptor->n_transf, descriptor->eps, &plan, opts);
makeplan<T>(type, ndim, descriptor->n_k, descriptor->iflag, descriptor->n_transf,
descriptor->eps, &plan, opts);
for (int64_t index = 0; index < descriptor->n_tot; ++index) {
int64_t j = index * descriptor->n_j * descriptor->n_transf;
int64_t k = index * n_k * descriptor->n_transf;
@@ -37,11 +37,10 @@ void run_nufft(int type, const NufftDescriptor<T>* descriptor, T *x, T *y, T *z,
delete opts;
}


template <int ndim, typename T>
void nufft1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
void nufft1(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
const NufftDescriptor<T> *descriptor = unpack_descriptor<NufftDescriptor<T>>(opaque, opaque_len);

std::complex<T> *c = reinterpret_cast<std::complex<T> *>(buffers[0]);
T *x = reinterpret_cast<T *>(buffers[1]);
T *y = NULL;
@@ -56,16 +55,16 @@ void nufft1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t
out_dim = 4;
}
std::complex<T> *F = reinterpret_cast<std::complex<T> *>(buffers[out_dim]);

run_nufft<ndim, T>(1, descriptor, x, y, z, c, F, stream);

ThrowIfError(cudaGetLastError());
}

template <int ndim, typename T>
void nufft2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
void nufft2(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
const NufftDescriptor<T> *descriptor = unpack_descriptor<NufftDescriptor<T>>(opaque, opaque_len);

std::complex<T> *F = reinterpret_cast<std::complex<T> *>(buffers[0]);
T *x = reinterpret_cast<T *>(buffers[1]);
T *y = NULL;
@@ -80,42 +79,42 @@ void nufft2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t
out_dim = 4;
}
std::complex<T> *c = reinterpret_cast<std::complex<T> *>(buffers[out_dim]);

run_nufft<ndim, T>(2, descriptor, x, y, z, c, F, stream);

ThrowIfError(cudaGetLastError());
}

void nufft2d1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
void nufft2d1(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
nufft1<2, double>(stream, buffers, opaque, opaque_len);
}

void nufft2d2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
void nufft2d2(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
nufft2<2, double>(stream, buffers, opaque, opaque_len);
}

void nufft3d1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
void nufft3d1(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
nufft1<3, double>(stream, buffers, opaque, opaque_len);
}

void nufft3d2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
void nufft3d2(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
nufft2<3, double>(stream, buffers, opaque, opaque_len);
}

void nufft2d1f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
void nufft2d1f(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
nufft1<2, float>(stream, buffers, opaque, opaque_len);
}

void nufft2d2f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
void nufft2d2f(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
nufft2<2, float>(stream, buffers, opaque, opaque_len);
}

void nufft3d1f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
void nufft3d1f(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
nufft1<3, float>(stream, buffers, opaque, opaque_len);
}

void nufft3d2f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) {
void nufft3d2f(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
nufft2<3, float>(stream, buffers, opaque, opaque_len);
}

}
} // namespace jax_finufft
3 changes: 1 addition & 2 deletions lib/kernels.h
@@ -6,7 +6,6 @@
#include <cstddef>
#include <cstdint>


namespace jax_finufft {

void nufft2d1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
@@ -21,4 +20,4 @@ void nufft3d2f(cudaStream_t stream, void** buffers, const char* opaque, std::siz

} // namespace jax_finufft

#endif
#endif
30 changes: 17 additions & 13 deletions pyproject.toml
@@ -1,22 +1,16 @@
[build-system]
requires = [
"pybind11>=2.6",
"scikit-build-core>=0.5",
]
requires = ["pybind11>=2.6", "scikit-build-core>=0.5"]
build-backend = "scikit_build_core.build"

[project]
name = "jax-finufft"
description = "Unofficial JAX bindings for finufft"
readme = "README.md"
authors = [{name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com"}]
authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }]
requires-python = ">=3.7"
license = {file = "LICENSE"}
urls = {Homepage = "https://github.com/dfm/jax-finufft"}
dependencies = [
"jax",
"jaxlib",
]
license = { file = "LICENSE" }
urls = { Homepage = "https://github.com/dfm/jax-finufft" }
dependencies = ["jax", "jaxlib"]
dynamic = ["version"]

[project.optional-dependencies]
@@ -32,5 +26,15 @@ build-dir = "build/{wheel_tag}"
[tool.setuptools_scm]
version_file = "src/jax_finufft/jax_finufft_version.py"

[tool.isort]
profile = "black"
[tool.black]
target-version = ["py39"]
line-length = 88

[tool.ruff]
line-length = 88
target-version = "py39"
exclude = []

[tool.ruff.isort]
known-first-party = ["jax-finufft"]
combine-as-imports = true
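
For context, the new [tool.ruff] and [tool.ruff.isort] tables take over the role of the removed [tool.isort] section. The snippet below is a rough, hypothetical illustration (not part of this diff) of the import grouping those settings are intended to keep: standard library, then third-party, then first-party, with combine-as-imports additionally merging duplicate aliased "from" imports onto one line.

# Hypothetical module, for illustration only: import grouping under the
# [tool.ruff.isort] settings above.
from functools import partial  # standard library

import numpy as np  # third-party

from jax_finufft import nufft1  # first-party, per known-first-party

# Hypothetical usage so the sketch stands alone.
points = np.linspace(-np.pi, np.pi, 8)
transform = partial(nufft1, (16,), eps=1e-6)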
8 changes: 6 additions & 2 deletions src/jax_finufft/ops.py
@@ -181,7 +181,9 @@ def batch(args, axes, *, output_shape, **kwargs):
nufft1_p = core.Primitive("nufft1")
nufft1_p.def_impl(partial(xla.apply_primitive, nufft1_p))
nufft1_p.def_abstract_eval(shapes.abstract_eval)
xla.register_translation(nufft1_p, partial(translation.translation_rule, "cpu"), platform="cpu")
xla.register_translation(
nufft1_p, partial(translation.translation_rule, "cpu"), platform="cpu"
)
if translation.jax_finufft_gpu is not None:
xla.register_translation(
nufft1_p, partial(translation.translation_rule, "gpu"), platform="cuda"
@@ -194,7 +196,9 @@ def batch(args, axes, *, output_shape, **kwargs):
nufft2_p = core.Primitive("nufft2")
nufft2_p.def_impl(partial(xla.apply_primitive, nufft2_p))
nufft2_p.def_abstract_eval(shapes.abstract_eval)
xla.register_translation(nufft2_p, partial(translation.translation_rule, "cpu"), platform="cpu")
xla.register_translation(
nufft2_p, partial(translation.translation_rule, "cpu"), platform="cpu"
)
if translation.jax_finufft_gpu is not None:
xla.register_translation(
nufft2_p, partial(translation.translation_rule, "gpu"), platform="cuda"
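
The registration calls reformatted above follow the xla.register_translation pattern that jax-finufft used at the time. Below is a minimal, self-contained sketch of that pattern; all names (demo_p, _demo_abstract_eval, _demo_rule) are placeholders rather than this package's implementations, and it assumes a JAX version that still exposes xla.register_translation, as the diff does.

# Minimal sketch of the primitive-registration pattern shown above.
# The real abstract-eval and translation rules live in jax_finufft.shapes
# and jax_finufft.translation; these are illustrative stand-ins.
from functools import partial

from jax import core
from jax.interpreters import xla


def _demo_abstract_eval(source, *points, output_shape, **kwargs):
    # Report the output shape/dtype without doing any work.
    return core.ShapedArray(output_shape, source.dtype)


def _demo_rule(platform, ctx, avals_in, avals_out, *args, **kwargs):
    # Placeholder lowering rule; a real one emits the custom call for `platform`.
    raise NotImplementedError(platform)


demo_p = core.Primitive("demo_nufft")
demo_p.def_impl(partial(xla.apply_primitive, demo_p))
demo_p.def_abstract_eval(_demo_abstract_eval)
xla.register_translation(demo_p, partial(_demo_rule, "cpu"), platform="cpu")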
2 changes: 1 addition & 1 deletion src/jax_finufft/translation.py
@@ -24,7 +24,7 @@ def translation_rule(
):
if platform == "gpu" and jax_finufft_gpu is None:
raise ValueError("jax-finufft was not compiled with GPU support")

ndim = len(points)
assert 1 <= ndim <= 3
if platform == "gpu" and ndim == 1:
18 changes: 11 additions & 7 deletions tests/ops_test.py
@@ -28,7 +28,7 @@ def test_nufft1_forward(ndim, x64, num_nonnuniform, num_uniform, iflag):
num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim))
ks = [np.arange(-np.floor(n / 2), np.floor((n - 1) / 2 + 1)) for n in num_uniform]

x = random.uniform(-np.pi, np.pi, size=(ndim,num_nonnuniform)).astype(dtype)
x = random.uniform(-np.pi, np.pi, size=(ndim, num_nonnuniform)).astype(dtype)
c = random.normal(size=num_nonnuniform) + 1j * random.normal(size=num_nonnuniform)
c = c.astype(cdtype)
f_expect = np.zeros(num_uniform, dtype=cdtype)
@@ -115,9 +115,11 @@ def test_nufft1_grad(ndim, num_nonnuniform, num_uniform, iflag):
func = partial(nufft1, num_uniform, eps=eps, iflag=iflag)
check_grads(func, (c, *x), 1, modes=("fwd", "rev"))

scalar_func = lambda *args: jnp.linalg.norm(func(*args))
def scalar_func(*args):
return jnp.linalg.norm(func(*args))

expect = jax.grad(scalar_func, argnums=tuple(range(len(x) + 1)))(c, *x)
for (n, g) in enumerate(expect):
for n, g in enumerate(expect):
np.testing.assert_allclose(jax.grad(scalar_func, argnums=(n,))(c, *x)[0], g)


@@ -148,9 +150,11 @@ def test_nufft2_grad(ndim, num_nonnuniform, num_uniform, iflag):
func = partial(nufft2, eps=eps, iflag=iflag)
check_grads(func, (f, *x), 1, modes=("fwd", "rev"))

scalar_func = lambda *args: jnp.linalg.norm(func(*args))
def scalar_func(*args):
return jnp.linalg.norm(func(*args))

expect = jax.grad(scalar_func, argnums=tuple(range(len(x) + 1)))(f, *x)
for (n, g) in enumerate(expect):
for n, g in enumerate(expect):
np.testing.assert_allclose(jax.grad(scalar_func, argnums=(n,))(f, *x)[0], g)


@@ -260,7 +264,7 @@ def test_multi_transform():
# TODO: is there a 2D or 3D version of this test?
if jax.default_backend() != "cpu":
pytest.skip("1D transforms not implemented on GPU")

random = np.random.default_rng(314)

n_tot, n_tr, n_j, n_k = 4, 10, 100, 12
@@ -280,7 +284,7 @@ def test_issue14():
def test_issue14():
if jax.default_backend() != "cpu":
pytest.skip("1D transforms not implemented on GPU")

M = 100
N = 200

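
The gradient tests reformatted above replace lambdas with named scalar_func helpers and drop the parenthesized loop tuples. The sketch below exercises nufft1 in the same spirit; the sizes, seed, and tolerance are arbitrary illustrative choices, and it assumes a CPU backend since 1D transforms are not implemented on GPU.

# Illustrative usage sketch only, mirroring the updated test style.
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.test_util import check_grads

from jax_finufft import nufft1

jax.config.update("jax_enable_x64", True)  # double precision, as in the x64 tests

rng = np.random.default_rng(42)  # arbitrary seed
n_pts, n_modes = 50, 16  # arbitrary sizes
x = rng.uniform(-np.pi, np.pi, size=n_pts)
c = rng.normal(size=n_pts) + 1j * rng.normal(size=n_pts)

func = partial(nufft1, (n_modes,), eps=1e-6, iflag=1)


def scalar_func(*args):
    # Named helper instead of a lambda, matching the updated tests.
    return jnp.linalg.norm(func(*args))


check_grads(func, (c, x), 1, modes=("fwd", "rev"))
print(scalar_func(c, x))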
