Replacing deprecated XLA translation rules with MLIR lowering rules #21

Merged
merged 3 commits on Oct 31, 2023
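
For context: the core of this PR is a one-for-one swap of JAX's deprecated xla.register_translation hook for mlir.register_lowering. The sketch below is not part of the diff; demo_p and _demo_rule are hypothetical stand-ins for the real nufft1_p/nufft2_p primitives and the lowering rule registered in ops.py further down.

from functools import partial

from jax import core
from jax.interpreters import mlir

# Hypothetical primitive standing in for nufft1_p / nufft2_p.
demo_p = core.Primitive("demo")

def _demo_rule(platform, ctx, x):
    # A real rule emits MLIR ops (e.g. a custom call); forwarding the
    # operand is enough to show the (ctx, *args) signature.
    del platform, ctx
    return [x]

# Deprecated style, removed by this PR:
#     xla.register_translation(demo_p, partial(_demo_rule, "cpu"), platform="cpu")
# Replacement:
mlir.register_lowering(demo_p, partial(_demo_rule, "cpu"), platform="cpu")
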
2 changes: 1 addition & 1 deletion README.md
@@ -97,7 +97,7 @@ c = nufft2(f, x, y, z) # 3D
This package, developed by Dan Foreman-Mackey, is licensed under the Apache
License, Version 2.0, with the following copyright:

-Copyright 2021 The Simons Foundation, Inc.
+Copyright 2021, 2022, 2023 The Simons Foundation, Inc.

If you use this software, please cite the primary references listed on the
[FINUFFT docs](https://finufft.readthedocs.io/en/latest/refs.html).
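
The context line in this hunk, c = nufft2(f, x, y, z), comes from the README's usage section. A minimal 1-D sketch of the documented interface (sizes are illustrative assumptions; points are assumed to lie in [-pi, pi)):

import numpy as np
from jax_finufft import nufft1, nufft2

M, N = 1000, 50
x = 2 * np.pi * np.random.uniform(size=M) - np.pi  # nonuniform points
c = np.random.standard_normal(M) + 1j * np.random.standard_normal(M)

f = nufft1(N, c, x)  # type 1: point strengths -> N uniform Fourier modes
c2 = nufft2(f, x)    # type 2: uniform modes -> values at the points
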
2 changes: 1 addition & 1 deletion src/jax_finufft/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The Simons Foundation, Inc.
+# Copyright 2021, 2022, 2023 The Simons Foundation, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
107 changes: 107 additions & 0 deletions src/jax_finufft/lowering.py
@@ -0,0 +1,107 @@
import numpy as np
from jax.interpreters.mlir import ir
from jax.lib import xla_client
from jaxlib.hlo_helpers import custom_call, hlo_const

from . import jax_finufft_cpu

try:
    from . import jax_finufft_gpu

    for _name, _value in jax_finufft_gpu.registrations().items():
        xla_client.register_custom_call_target(_name, _value, platform="gpu")
except ImportError:
    jax_finufft_gpu = None

for _name, _value in jax_finufft_cpu.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="cpu")


def default_layouts(*shapes):
    return [range(len(shape) - 1, -1, -1) for shape in shapes]


def lowering(platform, ctx, source, *points, output_shape, iflag, eps):
    del ctx

    if platform not in ["cpu", "gpu"]:
        raise ValueError(f"Unrecognized platform '{platform}'")

    if platform == "gpu" and jax_finufft_gpu is None:
        raise ValueError("jax-finufft was not compiled with GPU support")

    ndim = len(points)
    assert 1 <= ndim <= 3
    if platform == "gpu" and ndim == 1:
        raise ValueError("1-D transforms are not yet supported on the GPU")

    source_type = ir.RankedTensorType(source.type)
    points_type = [ir.RankedTensorType(x.type) for x in points]

    # Check supported and consistent dtypes
    f32 = ir.F32Type.get()
    f64 = ir.F64Type.get()
    source_dtype = source_type.element_type
    single = source_dtype == ir.ComplexType.get(f32) and all(
        x.element_type == f32 for x in points_type
    )
    double = source_dtype == ir.ComplexType.get(f64) and all(
        x.element_type == f64 for x in points_type
    )
    assert single or double
    suffix = "f" if single else ""

    # Check shapes
    source_shape = source_type.shape
    points_shape = tuple(x.shape for x in points_type)
    n_tot = source_shape[0]
    n_transf = source_shape[1]
    n_j = points_shape[0][1]
    if output_shape is None:
        op_name = f"nufft{ndim}d2{suffix}".encode("ascii")
        n_k = np.array(source_shape[2:], dtype=np.int64)
        full_output_shape = tuple(source_shape[:2]) + (n_j,)
    else:
        op_name = f"nufft{ndim}d1{suffix}".encode("ascii")
        n_k = np.array(output_shape, dtype=np.int64)
        full_output_shape = tuple(source_shape[:2]) + tuple(output_shape)

    # The backend expects the output shape in Fortran order so we'll just
    # fake it here, by sending in n_k and x in the reverse order.
    n_k_full = np.zeros(3, dtype=np.int64)
    n_k_full[:ndim] = n_k[::-1]

    # Build the descriptor containing the transform parameters
    opaque = getattr(jax_finufft_cpu, f"build_descriptor{suffix}")(
        eps, iflag, n_tot, n_transf, n_j, *n_k_full
    )

    if platform == "cpu":
        opaque_arg = hlo_const(np.frombuffer(opaque, dtype=np.uint8))
        opaque_shape = ir.RankedTensorType(opaque_arg.type).shape
        return custom_call(
            op_name,
            result_types=[
                ir.RankedTensorType.get(full_output_shape, source_type.element_type)
            ],
            # Reverse points because backend uses Fortran order
            operands=[opaque_arg, source, *points[::-1]],
            backend_config=opaque,
            operand_layouts=default_layouts(
                opaque_shape, source_shape, *points_shape[::-1]
            ),
            result_layouts=default_layouts(full_output_shape),
        ).results

    else:
        return custom_call(
            op_name,
            result_types=[
                ir.RankedTensorType.get(full_output_shape, source_type.element_type)
            ],
            # Reverse points because backend uses Fortran order
            operands=[source, *points[::-1]],
            backend_config=opaque,
            operand_layouts=default_layouts(source_shape, *points_shape[::-1]),
            result_layouts=default_layouts(full_output_shape),
        ).results
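
An aside for readers skimming the diff: default_layouts above returns each shape's dimensions in minor-to-major order, which is how XLA custom calls describe C-ordered (row-major) buffers. The backend's Fortran-order expectation is handled separately, by reversing n_k and the points operands rather than by changing these layouts. A standalone illustration (same helper body, repeated here for clarity):

def default_layouts(*shapes):
    # Minor-to-major dimension order for a C-contiguous buffer.
    return [range(len(shape) - 1, -1, -1) for shape in shapes]

print([list(r) for r in default_layouts((4, 5), (2, 3, 4))])
# [[1, 0], [2, 1, 0]]
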
24 changes: 8 additions & 16 deletions src/jax_finufft/ops.py
@@ -6,9 +6,9 @@
from jax import core
from jax import jit
from jax import numpy as jnp
-from jax.interpreters import ad, batching, xla
+from jax.interpreters import ad, batching, xla, mlir

-from . import shapes, translation
+from . import shapes, lowering


@partial(jit, static_argnums=(0,), static_argnames=("iflag", "eps"))
@@ -181,13 +181,9 @@ def batch(args, axes, *, output_shape, **kwargs):
nufft1_p = core.Primitive("nufft1")
nufft1_p.def_impl(partial(xla.apply_primitive, nufft1_p))
nufft1_p.def_abstract_eval(shapes.abstract_eval)
-xla.register_translation(
-    nufft1_p, partial(translation.translation_rule, "cpu"), platform="cpu"
-)
-if translation.jax_finufft_gpu is not None:
-    xla.register_translation(
-        nufft1_p, partial(translation.translation_rule, "gpu"), platform="cuda"
-    )
+mlir.register_lowering(nufft1_p, partial(lowering.lowering, "cpu"), platform="cpu")
+if lowering.jax_finufft_gpu is not None:
+    mlir.register_lowering(nufft1_p, partial(lowering.lowering, "gpu"), platform="gpu")
ad.primitive_jvps[nufft1_p] = partial(jvp, nufft1_p)
ad.primitive_transposes[nufft1_p] = transpose
batching.primitive_batchers[nufft1_p] = batch
@@ -196,13 +192,9 @@ def batch(args, axes, *, output_shape, **kwargs):
nufft2_p = core.Primitive("nufft2")
nufft2_p.def_impl(partial(xla.apply_primitive, nufft2_p))
nufft2_p.def_abstract_eval(shapes.abstract_eval)
-xla.register_translation(
-    nufft2_p, partial(translation.translation_rule, "cpu"), platform="cpu"
-)
-if translation.jax_finufft_gpu is not None:
-    xla.register_translation(
-        nufft2_p, partial(translation.translation_rule, "gpu"), platform="cuda"
-    )
+mlir.register_lowering(nufft2_p, partial(lowering.lowering, "cpu"), platform="cpu")
+if lowering.jax_finufft_gpu is not None:
+    mlir.register_lowering(nufft2_p, partial(lowering.lowering, "gpu"), platform="gpu")
Collaborator Author:
@lgarrison — Doh! I think I found the issue lol

ad.primitive_jvps[nufft2_p] = partial(jvp, nufft2_p)
ad.primitive_transposes[nufft2_p] = transpose
batching.primitive_batchers[nufft2_p] = batch
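
A small design note on the registrations above: lowering.lowering takes the platform as its first positional argument, and functools.partial pins it so one function serves both backends while matching the (ctx, *args, **params) signature that mlir.register_lowering expects. A sketch with a hypothetical rule:

from functools import partial

def rule(platform, ctx, *args, **params):
    # partial application supplies the leading platform argument up front.
    print(f"lowering for {platform}")

cpu_rule = partial(rule, "cpu")  # what gets registered for platform="cpu"
gpu_rule = partial(rule, "gpu")  # and for platform="gpu"
cpu_rule(None)  # -> lowering for cpu
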
128 changes: 0 additions & 128 deletions src/jax_finufft/translation.py

This file was deleted.
