diff --git a/README.md b/README.md
index 8fb6a21..dc6f6c0 100644
--- a/README.md
+++ b/README.md
@@ -97,7 +97,7 @@ c = nufft2(f, x, y, z)  # 3D
 This package, developed by Dan Foreman-Mackey is licensed under the Apache
 License, Version 2.0, with the following copyright:
 
-Copyright 2021 The Simons Foundation, Inc.
+Copyright 2021, 2022, 2023 The Simons Foundation, Inc.
 
 If you use this software, please cite the primary references listed on the
 [FINUFFT docs](https://finufft.readthedocs.io/en/latest/refs.html).
diff --git a/src/jax_finufft/__init__.py b/src/jax_finufft/__init__.py
index 931e67b..57997be 100644
--- a/src/jax_finufft/__init__.py
+++ b/src/jax_finufft/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The Simons Foundation, Inc.
+# Copyright 2021, 2022, 2023 The Simons Foundation, Inc.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/src/jax_finufft/lowering.py b/src/jax_finufft/lowering.py
new file mode 100644
index 0000000..87117c5
--- /dev/null
+++ b/src/jax_finufft/lowering.py
@@ -0,0 +1,107 @@
+import numpy as np
+from jax.interpreters.mlir import ir
+from jax.lib import xla_client
+from jaxlib.hlo_helpers import custom_call, hlo_const
+
+from . import jax_finufft_cpu
+
+try:
+    from . import jax_finufft_gpu
+
+    for _name, _value in jax_finufft_gpu.registrations().items():
+        xla_client.register_custom_call_target(_name, _value, platform="gpu")
+except ImportError:
+    jax_finufft_gpu = None
+
+for _name, _value in jax_finufft_cpu.registrations().items():
+    xla_client.register_custom_call_target(_name, _value, platform="cpu")
+
+
+def default_layouts(*shapes):
+    return [range(len(shape) - 1, -1, -1) for shape in shapes]
+
+
+def lowering(platform, ctx, source, *points, output_shape, iflag, eps):
+    del ctx
+
+    if platform not in ["cpu", "gpu"]:
+        raise ValueError(f"Unrecognized platform '{platform}'")
+
+    if platform == "gpu" and jax_finufft_gpu is None:
+        raise ValueError("jax-finufft was not compiled with GPU support")
+
+    ndim = len(points)
+    assert 1 <= ndim <= 3
+    if platform == "gpu" and ndim == 1:
+        raise ValueError("1-D transforms are not yet supported on the GPU")
+
+    source_type = ir.RankedTensorType(source.type)
+    points_type = [ir.RankedTensorType(x.type) for x in points]
+
+    # Check supported and consistent dtypes
+    f32 = ir.F32Type.get()
+    f64 = ir.F64Type.get()
+    source_dtype = source_type.element_type
+    single = source_dtype == ir.ComplexType.get(f32) and all(
+        x.element_type == f32 for x in points_type
+    )
+    double = source_dtype == ir.ComplexType.get(f64) and all(
+        x.element_type == f64 for x in points_type
+    )
+    assert single or double
+    suffix = "f" if single else ""
+
+    # Check shapes
+    source_shape = source_type.shape
+    points_shape = tuple(x.shape for x in points_type)
+    n_tot = source_shape[0]
+    n_transf = source_shape[1]
+    n_j = points_shape[0][1]
+    if output_shape is None:
+        op_name = f"nufft{ndim}d2{suffix}".encode("ascii")
+        n_k = np.array(source_shape[2:], dtype=np.int64)
+        full_output_shape = tuple(source_shape[:2]) + (n_j,)
+    else:
+        op_name = f"nufft{ndim}d1{suffix}".encode("ascii")
+        n_k = np.array(output_shape, dtype=np.int64)
+        full_output_shape = tuple(source_shape[:2]) + tuple(output_shape)
+
+    # The backend expects the output shape in Fortran order so we'll just
+    # fake it here, by sending in n_k and x in the reverse order.
+    n_k_full = np.zeros(3, dtype=np.int64)
+    n_k_full[:ndim] = n_k[::-1]
+
+    # Build the descriptor containing the transform parameters
+    opaque = getattr(jax_finufft_cpu, f"build_descriptor{suffix}")(
+        eps, iflag, n_tot, n_transf, n_j, *n_k_full
+    )
+
+    if platform == "cpu":
+        opaque_arg = hlo_const(np.frombuffer(opaque, dtype=np.uint8))
+        opaque_shape = ir.RankedTensorType(opaque_arg.type).shape
+        return custom_call(
+            op_name,
+            result_types=[
+                ir.RankedTensorType.get(full_output_shape, source_type.element_type)
+            ],
+            # Reverse points because backend uses Fortran order
+            operands=[opaque_arg, source, *points[::-1]],
+            backend_config=opaque,
+            operand_layouts=default_layouts(
+                opaque_shape, source_shape, *points_shape[::-1]
+            ),
+            result_layouts=default_layouts(full_output_shape),
+        ).results
+
+    else:
+        return custom_call(
+            op_name,
+            result_types=[
+                ir.RankedTensorType.get(full_output_shape, source_type.element_type)
+            ],
+            # Reverse points because backend uses Fortran order
+            operands=[source, *points[::-1]],
+            backend_config=opaque,
+            operand_layouts=default_layouts(source_shape, *points_shape[::-1]),
+            result_layouts=default_layouts(full_output_shape),
+        ).results
diff --git a/src/jax_finufft/ops.py b/src/jax_finufft/ops.py
index a3de783..f6db90f 100644
--- a/src/jax_finufft/ops.py
+++ b/src/jax_finufft/ops.py
@@ -6,9 +6,9 @@
 from jax import core
 from jax import jit
 from jax import numpy as jnp
-from jax.interpreters import ad, batching, xla
+from jax.interpreters import ad, batching, xla, mlir
 
-from . import shapes, translation
+from . import shapes, lowering
 
 
 @partial(jit, static_argnums=(0,), static_argnames=("iflag", "eps"))
@@ -181,13 +181,9 @@ def batch(args, axes, *, output_shape, **kwargs):
 nufft1_p = core.Primitive("nufft1")
 nufft1_p.def_impl(partial(xla.apply_primitive, nufft1_p))
 nufft1_p.def_abstract_eval(shapes.abstract_eval)
-xla.register_translation(
-    nufft1_p, partial(translation.translation_rule, "cpu"), platform="cpu"
-)
-if translation.jax_finufft_gpu is not None:
-    xla.register_translation(
-        nufft1_p, partial(translation.translation_rule, "gpu"), platform="cuda"
-    )
+mlir.register_lowering(nufft1_p, partial(lowering.lowering, "cpu"), platform="cpu")
+if lowering.jax_finufft_gpu is not None:
+    mlir.register_lowering(nufft1_p, partial(lowering.lowering, "gpu"), platform="gpu")
 ad.primitive_jvps[nufft1_p] = partial(jvp, nufft1_p)
 ad.primitive_transposes[nufft1_p] = transpose
 batching.primitive_batchers[nufft1_p] = batch
@@ -196,13 +192,9 @@ def batch(args, axes, *, output_shape, **kwargs):
 nufft2_p = core.Primitive("nufft2")
 nufft2_p.def_impl(partial(xla.apply_primitive, nufft2_p))
 nufft2_p.def_abstract_eval(shapes.abstract_eval)
-xla.register_translation(
-    nufft2_p, partial(translation.translation_rule, "cpu"), platform="cpu"
-)
-if translation.jax_finufft_gpu is not None:
-    xla.register_translation(
-        nufft2_p, partial(translation.translation_rule, "gpu"), platform="cuda"
-    )
+mlir.register_lowering(nufft2_p, partial(lowering.lowering, "cpu"), platform="cpu")
+if lowering.jax_finufft_gpu is not None:
+    mlir.register_lowering(nufft2_p, partial(lowering.lowering, "gpu"), platform="gpu")
 ad.primitive_jvps[nufft2_p] = partial(jvp, nufft2_p)
 ad.primitive_transposes[nufft2_p] = transpose
 batching.primitive_batchers[nufft2_p] = batch
diff --git a/src/jax_finufft/translation.py b/src/jax_finufft/translation.py
deleted file mode 100644
index 8277f11..0000000
--- a/src/jax_finufft/translation.py
+++ /dev/null
@@ -1,128 +0,0 @@
-__all__ = ["translation_rule"]
-
-import numpy as np
-from jax.lib import xla_client
-
-from . import jax_finufft_cpu
-
-try:
-    from . import jax_finufft_gpu
-
-    for _name, _value in jax_finufft_gpu.registrations().items():
-        xla_client.register_custom_call_target(_name, _value, platform="CUDA")
-except ImportError:
-    jax_finufft_gpu = None
-
-for _name, _value in jax_finufft_cpu.registrations().items():
-    xla_client.register_custom_call_target(_name, _value, platform="cpu")
-
-xops = xla_client.ops
-
-
-def translation_rule(
-    platform, ctx, avals_in, avals_out, source, *points, output_shape, iflag, eps
-):
-    if platform == "gpu" and jax_finufft_gpu is None:
-        raise ValueError("jax-finufft was not compiled with GPU support")
-
-    ndim = len(points)
-    assert 1 <= ndim <= 3
-    if platform == "gpu" and ndim == 1:
-        raise ValueError("1-D transforms are not yet supported on the GPU")
-
-    c = ctx.builder
-    source_shape_info = c.get_shape(source)
-    points_shape_info = list(map(c.get_shape, points))
-
-    # Check supported and consistent dtypes
-    source_dtype = source_shape_info.element_type()
-    single = source_dtype == np.csingle and all(
-        x.element_type() == np.single for x in points_shape_info
-    )
-    double = source_dtype == np.cdouble and all(
-        x.element_type() == np.double for x in points_shape_info
-    )
-    assert single or double
-    suffix = "f" if source_dtype == np.csingle else ""
-
-    # Check shapes
-    source_shape = source_shape_info.dimensions()
-    points_shape = tuple(x.dimensions() for x in points_shape_info)
-    n_tot = source_shape[0]
-    n_transf = source_shape[1]
-    n_j = points_shape[0][1]
-    if output_shape is None:
-        op_name = f"nufft{ndim}d2{suffix}".encode("ascii")
-        n_k = np.array(source_shape[2:], dtype=np.int64)
-        full_output_shape = source_shape[:2] + (n_j,)
-    else:
-        op_name = f"nufft{ndim}d1{suffix}".encode("ascii")
-        n_k = np.array(output_shape, dtype=np.int64)
-        full_output_shape = source_shape[:2] + tuple(output_shape)
-
-    # The backend expects the output shape in Fortran order so we'll just
-    # fake it here, by sending in n_k and x in the reverse order.
-    n_k_full = np.zeros(3, dtype=np.int64)
-    n_k_full[:ndim] = n_k[::-1]
-
-    # Dispatch to the right op
-    desc = getattr(jax_finufft_cpu, f"build_descriptor{suffix}")(
-        eps, iflag, n_tot, n_transf, n_j, *n_k_full
-    )
-
-    # Set up most of the arguments
-    operands = (
-        source,
-        *points[::-1],  # Reverse order because backend uses Fortran order
-    )
-    operand_shapes_with_layout = (
-        xla_client.Shape.array_shape(
-            source_dtype,
-            source_shape,
-            tuple(range(len(source_shape) - 1, -1, -1)),
-        ),
-    ) + tuple(
-        xla_client.Shape.array_shape(
-            x.element_type(),
-            x.dimensions(),
-            tuple(range(len(x.dimensions()) - 1, -1, -1)),
-        )
-        for x in points_shape_info[::-1]  # Reverse order, again
-    )
-    shape_with_layout = xla_client.Shape.array_shape(
-        source_dtype,
-        full_output_shape,
-        tuple(range(len(full_output_shape) - 1, -1, -1)),
-    )
-
-    if platform == "cpu":
-        return [
-            xops.CustomCallWithLayout(
-                c,
-                op_name,
-                operands=(xops.ConstantLiteral(c, np.frombuffer(desc, dtype=np.uint8)),)
-                + operands,
-                operand_shapes_with_layout=(
-                    xla_client.Shape.array_shape(
-                        np.dtype(np.uint8), (len(desc),), (0,)
-                    ),
-                )
-                + operand_shapes_with_layout,
-                shape_with_layout=shape_with_layout,
-            )
-        ]
-
-    elif platform == "gpu":
-        return [
-            xops.CustomCallWithLayout(
-                c,
-                op_name,
-                operands=operands,
-                operand_shapes_with_layout=operand_shapes_with_layout,
-                shape_with_layout=shape_with_layout,
-                opaque=desc,
-            )
-        ]
-
-    else:
-        raise ValueError(f"Unrecognized platform '{platform}'")
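
A minimal usage sketch (not part of the diff) of the code path these registrations serve: calling the public API defined in ops.py under JIT is what emits the custom_call built in lowering.py. The nufft1 call signature and keyword names below are assumed from the README and the static_argnames in ops.py.

    import jax.numpy as jnp
    from jax_finufft import nufft1

    # 1-D type-1 transform on CPU: complex64 strengths at float32 nonuniform points.
    # (Signature assumed: nufft1(output_shape, source, *points, iflag=..., eps=...).)
    x = jnp.linspace(-jnp.pi, jnp.pi, 1000)  # nonuniform points
    c = jnp.exp(1j * x)                      # source strengths at those points
    # nufft1 is jit-compiled in ops.py, so this call goes through the MLIR
    # lowering registered above and lands on the "nufft1d1f" custom call.
    f = nufft1(50, c, x, eps=1e-6, iflag=1)  # 50 output Fourier modes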