From 5c9445d77e366e79432a5c24eced84048dde5f75 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 10 Dec 2021 14:24:39 -0500 Subject: [PATCH] A new PR to generalize the vmap behavior (#12) * non-batched dims * getting more general vmaps to work --- src/jax_finufft/ops.py | 61 +++++++++++++++++++++++++++++---- tests/ops_test.py | 78 ++++++++++++++++++++++++++++++------------ 2 files changed, 110 insertions(+), 29 deletions(-) diff --git a/src/jax_finufft/ops.py b/src/jax_finufft/ops.py index 87dc832..dddcc1e 100644 --- a/src/jax_finufft/ops.py +++ b/src/jax_finufft/ops.py @@ -32,6 +32,7 @@ def nufft1(output_shape, source, *points, iflag=1, eps=1e-6): # Handle broadcasting expected_output_shape = source.shape[:-1] + tuple(output_shape) + source, points = pad_shapes(1, source, *points) if points[0].shape[-1] != source.shape[-1]: raise ValueError("The final dimension of 'source' must match 'points'") @@ -275,14 +276,60 @@ def transpose(type_, doutput, source, *points, output_shape, eps, iflag): def batch(type_, prim, args, axes, **kwargs): - ndim = len(args) - 1 - if type_ == 1: - mx = args[0].ndim - 2 + source, *points = args + bsource, *bpoints = axes + + # TODO: the following logic doesn't work yet. If none of the points are batched, + # we should be able to get a faster computation by stacking the transforms + # into a single transform and then reshaping. It might be worth making the + # stacked axes into an explicit parameter rather than just trying to infer + # it. + + # # If none of the points are being mapped, we can get a faster computation using + # # a single transform with num_transforms * num_repeats + # if all(bx is batching.not_mapped for bx in bpoints): + # assert bsource is not batching.not_mapped + + # ndim = len(points) + # in_axis = -2 if type_ == 1 else -ndim - 1 + # out_axis = -2 if type_ == 2 else -ndim - 1 + # num_repeats = source.shape[bsource] + + # source = batching.moveaxis(source, bsource, in_axis) + # source = source.reshape( + # source.shape[: in_axis - 1] + (-1,) + source.shape[in_axis + 1 :] + # ) + # result = prim.bind(source, *points, **kwargs) + # return ( + # result.reshape( + # result.shape[: out_axis - 1] + # + (-1, source.shape[in_axis]) + # + result.shape[out_axis + 1 :] + # ), + # out_axis, + # ) + + # Otherwise move the batching dimension to the front and repeat the arrays + # to the right shape + if bsource is None: + assert any(bx is not batching.not_mapped for bx in bpoints) + num_repeats = next( + x.shape[bx] + for x, bx in zip(points, bpoints) + if bx is not batching.not_mapped + ) + source = jnp.repeat(source[jnp.newaxis], num_repeats, axis=0) else: - mx = args[0].ndim - ndim - 1 - assert all(a < mx for a in axes) - assert all(a == axes[0] for a in axes[1:]) - return prim.bind(*args, **kwargs), axes[0] + num_repeats = source.shape[bsource] + source = batching.moveaxis(source, bsource, 0) + + mapped_points = [] + for x, bx in zip(points, bpoints): + if bx is batching.not_mapped: + mapped_points.append(jnp.repeat(x[jnp.newaxis], num_repeats, axis=0)) + else: + mapped_points.append(batching.moveaxis(x, bx, 0)) + return prim.bind(source, *mapped_points, **kwargs), 0 def pad_shapes(output_dim, source, *points): diff --git a/tests/ops_test.py b/tests/ops_test.py index 4ba73b0..fd25863 100644 --- a/tests/ops_test.py +++ b/tests/ops_test.py @@ -145,24 +145,41 @@ def test_nufft1_vmap(ndim, num_nonnuniform, num_uniform, iflag): dtype = np.double cdtype = np.cdouble + num_repeat = 5 num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim)) x = [ - random.uniform(-np.pi, np.pi, size=num_nonnuniform).astype(dtype) + random.uniform(-np.pi, np.pi, size=(num_repeat, num_nonnuniform)).astype(dtype) for _ in range(ndim) ] - c = random.normal(size=num_nonnuniform) + 1j * random.normal(size=num_nonnuniform) + c = random.normal(size=(num_repeat, num_nonnuniform)) + 1j * random.normal( + size=(num_repeat, num_nonnuniform) + ) c = c.astype(cdtype) - - num = 5 - xs = [jnp.repeat(x_[None], num, axis=0) for x_ in x] - cs = jnp.repeat(c[None], num, axis=0) - func = partial(nufft1, num_uniform, iflag=iflag) - calc = jax.vmap(func)(cs, *xs) - expect = func(c, *x) - for n in range(num): - np.testing.assert_allclose(calc[n], expect) + + # Start by checking the full basic vmap + calc = jax.vmap(func)(c, *x) + for n in range(num_repeat): + np.testing.assert_allclose(calc[n], func(c[n], *(x_[n] for x_ in x))) + + # With different in_axes + calc_ax = jax.vmap(func, in_axes=(1,) + (0,) * ndim)(jnp.moveaxis(c, 0, 1), *x) + np.testing.assert_allclose(calc_ax, calc) + + # With unmapped source axis + calc_unmap = jax.vmap(func, in_axes=(None,) + (0,) * ndim)(c[0], *x) + for n in range(num_repeat): + np.testing.assert_allclose(calc_unmap[n], func(c[0], *(x_[n] for x_ in x))) + + # With unmapped points axis + calc_unmap_pt = jax.vmap(func, in_axes=(0,) + (0,) * (ndim - 1) + (None,))( + c, *x[:-1], x[-1][0] + ) + for n in range(num_repeat): + np.testing.assert_allclose( + calc_unmap_pt[n], func(c[n], *(x_[n] for x_ in x[:-1]), x[-1][0]) + ) @pytest.mark.parametrize( @@ -175,24 +192,41 @@ def test_nufft2_vmap(ndim, num_nonnuniform, num_uniform, iflag): dtype = np.double cdtype = np.cdouble + num_repeat = 5 num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim)) x = [ - random.uniform(-np.pi, np.pi, size=num_nonnuniform).astype(dtype) + random.uniform(-np.pi, np.pi, size=(num_repeat, num_nonnuniform)).astype(dtype) for _ in range(ndim) ] - f = random.normal(size=num_uniform) + 1j * random.normal(size=num_uniform) + f = random.normal(size=(num_repeat,) + num_uniform) + 1j * random.normal( + size=(num_repeat,) + num_uniform + ) f = f.astype(cdtype) - - num = 5 - xs = [jnp.repeat(x_[jnp.newaxis], num, axis=0) for x_ in x] - fs = jnp.repeat(f[jnp.newaxis], num, axis=0) - func = partial(nufft2, iflag=iflag) - calc = jax.vmap(func)(fs, *xs) - expect = func(f, *x) - for n in range(num): - np.testing.assert_allclose(calc[n], expect) + + # Start by checking the full basic vmap + calc = jax.vmap(func)(f, *x) + for n in range(num_repeat): + np.testing.assert_allclose(calc[n], func(f[n], *(x_[n] for x_ in x))) + + # With different in_axes + calc_ax = jax.vmap(func, in_axes=(1,) + (0,) * ndim)(jnp.moveaxis(f, 0, 1), *x) + np.testing.assert_allclose(calc_ax, calc) + + # With unmapped source axis + calc_unmap = jax.vmap(func, in_axes=(None,) + (0,) * ndim)(f[0], *x) + for n in range(num_repeat): + np.testing.assert_allclose(calc_unmap[n], func(f[0], *(x_[n] for x_ in x))) + + # With unmapped points axis + calc_unmap_pt = jax.vmap(func, in_axes=(0,) + (0,) * (ndim - 1) + (None,))( + f, *x[:-1], x[-1][0] + ) + for n in range(num_repeat): + np.testing.assert_allclose( + calc_unmap_pt[n], func(f[n], *(x_[n] for x_ in x[:-1]), x[-1][0]) + ) def test_multi_transform():