Skip to content

Commit

Permalink
Removing deprecated XLA translation rules
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Oct 30, 2023
1 parent b9296a9 commit 20b6e35
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 144 deletions.
107 changes: 107 additions & 0 deletions src/jax_finufft/lowering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import numpy as np
from jax.interpreters.mlir import ir
from jax.lib import xla_client
from jaxlib.hlo_helpers import custom_call, hlo_const

from . import jax_finufft_cpu

try:
from . import jax_finufft_gpu

for _name, _value in jax_finufft_gpu.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="gpu")
except ImportError:
jax_finufft_gpu = None

for _name, _value in jax_finufft_cpu.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="cpu")


def default_layouts(*shapes):
return [range(len(shape) - 1, -1, -1) for shape in shapes]


def lowering(platform, ctx, source, *points, output_shape, iflag, eps):
del ctx

if platform not in ["cpu", "gpu"]:
raise ValueError(f"Unrecognized platform '{platform}'")

if platform == "gpu" and jax_finufft_gpu is None:
raise ValueError("jax-finufft was not compiled with GPU support")

ndim = len(points)
assert 1 <= ndim <= 3
if platform == "gpu" and ndim == 1:
raise ValueError("1-D transforms are not yet supported on the GPU")

source_type = ir.RankedTensorType(source.type)
points_type = [ir.RankedTensorType(x.type) for x in points]

# Check supported and consistent dtypes
f32 = ir.F32Type.get()
f64 = ir.F64Type.get()
source_dtype = source_type.element_type
single = source_dtype == ir.ComplexType.get(f32) and all(
x.element_type == f32 for x in points_type
)
double = source_dtype == ir.ComplexType.get(f64) and all(
x.element_type == f64 for x in points_type
)
assert single or double
suffix = "f" if single else ""

# Check shapes
source_shape = source_type.shape
points_shape = tuple(x.shape for x in points_type)
n_tot = source_shape[0]
n_transf = source_shape[1]
n_j = points_shape[0][1]
if output_shape is None:
op_name = f"nufft{ndim}d2{suffix}".encode("ascii")
n_k = np.array(source_shape[2:], dtype=np.int64)
full_output_shape = tuple(source_shape[:2]) + (n_j,)
else:
op_name = f"nufft{ndim}d1{suffix}".encode("ascii")
n_k = np.array(output_shape, dtype=np.int64)
full_output_shape = tuple(source_shape[:2]) + tuple(output_shape)

# The backend expects the output shape in Fortran order so we'll just
# fake it here, by sending in n_k and x in the reverse order.
n_k_full = np.zeros(3, dtype=np.int64)
n_k_full[:ndim] = n_k[::-1]

# Build the descriptor containing the transform parameters
opaque = getattr(jax_finufft_cpu, f"build_descriptor{suffix}")(
eps, iflag, n_tot, n_transf, n_j, *n_k_full
)

if platform == "cpu":
opaque_arg = hlo_const(np.frombuffer(opaque, dtype=np.uint8))
opaque_shape = ir.RankedTensorType(opaque_arg.type).shape
return custom_call(
op_name,
result_types=[
ir.RankedTensorType.get(full_output_shape, source_type.element_type)
],
# Reverse points because backend uses Fortran order
operands=[opaque_arg, source, *points[::-1]],
backend_config=opaque,
operand_layouts=default_layouts(
opaque_shape, source_shape, *points_shape[::-1]
),
result_layouts=default_layouts(full_output_shape),
).results

else:
return custom_call(
op_name,
result_types=[
ir.RankedTensorType.get(full_output_shape, source_type.element_type)
],
# Reverse points because backend uses Fortran order
operands=[source, *points[::-1]],
backend_config=opaque,
operand_layouts=default_layouts(source_shape, *points_shape[::-1]),
result_layouts=default_layouts(full_output_shape),
).results
24 changes: 8 additions & 16 deletions src/jax_finufft/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from jax import core
from jax import jit
from jax import numpy as jnp
from jax.interpreters import ad, batching, xla
from jax.interpreters import ad, batching, xla, mlir

from . import shapes, translation
from . import shapes, lowering


@partial(jit, static_argnums=(0,), static_argnames=("iflag", "eps"))
Expand Down Expand Up @@ -181,13 +181,9 @@ def batch(args, axes, *, output_shape, **kwargs):
nufft1_p = core.Primitive("nufft1")
nufft1_p.def_impl(partial(xla.apply_primitive, nufft1_p))
nufft1_p.def_abstract_eval(shapes.abstract_eval)
xla.register_translation(
nufft1_p, partial(translation.translation_rule, "cpu"), platform="cpu"
)
if translation.jax_finufft_gpu is not None:
xla.register_translation(
nufft1_p, partial(translation.translation_rule, "gpu"), platform="cuda"
)
mlir.register_lowering(nufft1_p, partial(lowering.lowering, "cpu"), platform="cpu")
if lowering.jax_finufft_gpu is not None:
mlir.register_lowering(nufft1_p, partial(lowering.lowering, "cpu"), platform="gpu")
ad.primitive_jvps[nufft1_p] = partial(jvp, nufft1_p)
ad.primitive_transposes[nufft1_p] = transpose
batching.primitive_batchers[nufft1_p] = batch
Expand All @@ -196,13 +192,9 @@ def batch(args, axes, *, output_shape, **kwargs):
nufft2_p = core.Primitive("nufft2")
nufft2_p.def_impl(partial(xla.apply_primitive, nufft2_p))
nufft2_p.def_abstract_eval(shapes.abstract_eval)
xla.register_translation(
nufft2_p, partial(translation.translation_rule, "cpu"), platform="cpu"
)
if translation.jax_finufft_gpu is not None:
xla.register_translation(
nufft2_p, partial(translation.translation_rule, "gpu"), platform="cuda"
)
mlir.register_lowering(nufft2_p, partial(lowering.lowering, "cpu"), platform="cpu")
if lowering.jax_finufft_gpu is not None:
mlir.register_lowering(nufft2_p, partial(lowering.lowering, "cpu"), platform="gpu")
ad.primitive_jvps[nufft2_p] = partial(jvp, nufft2_p)
ad.primitive_transposes[nufft2_p] = transpose
batching.primitive_batchers[nufft2_p] = batch
128 changes: 0 additions & 128 deletions src/jax_finufft/translation.py

This file was deleted.

0 comments on commit 20b6e35

Please sign in to comment.