Skip to content

Commit

Permalink
A new PR to generalize the vmap behavior (#12)
Browse files Browse the repository at this point in the history
* non-batched dims

* getting more general vmaps to work
  • Loading branch information
dfm authored Dec 10, 2021
1 parent 6b74877 commit 5c9445d
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 29 deletions.
61 changes: 54 additions & 7 deletions src/jax_finufft/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def nufft1(output_shape, source, *points, iflag=1, eps=1e-6):

# Handle broadcasting
expected_output_shape = source.shape[:-1] + tuple(output_shape)

source, points = pad_shapes(1, source, *points)
if points[0].shape[-1] != source.shape[-1]:
raise ValueError("The final dimension of 'source' must match 'points'")
Expand Down Expand Up @@ -275,14 +276,60 @@ def transpose(type_, doutput, source, *points, output_shape, eps, iflag):


def batch(type_, prim, args, axes, **kwargs):
ndim = len(args) - 1
if type_ == 1:
mx = args[0].ndim - 2
source, *points = args
bsource, *bpoints = axes

# TODO: the following logic doesn't work yet. If none of the points are batched,
# we should be able to get a faster computation by stacking the transforms
# into a single transform and then reshaping. It might be worth making the
# stacked axes into an explicit parameter rather than just trying to infer
# it.

# # If none of the points are being mapped, we can get a faster computation using
# # a single transform with num_transforms * num_repeats
# if all(bx is batching.not_mapped for bx in bpoints):
# assert bsource is not batching.not_mapped

# ndim = len(points)
# in_axis = -2 if type_ == 1 else -ndim - 1
# out_axis = -2 if type_ == 2 else -ndim - 1
# num_repeats = source.shape[bsource]

# source = batching.moveaxis(source, bsource, in_axis)
# source = source.reshape(
# source.shape[: in_axis - 1] + (-1,) + source.shape[in_axis + 1 :]
# )
# result = prim.bind(source, *points, **kwargs)
# return (
# result.reshape(
# result.shape[: out_axis - 1]
# + (-1, source.shape[in_axis])
# + result.shape[out_axis + 1 :]
# ),
# out_axis,
# )

# Otherwise move the batching dimension to the front and repeat the arrays
# to the right shape
if bsource is None:
assert any(bx is not batching.not_mapped for bx in bpoints)
num_repeats = next(
x.shape[bx]
for x, bx in zip(points, bpoints)
if bx is not batching.not_mapped
)
source = jnp.repeat(source[jnp.newaxis], num_repeats, axis=0)
else:
mx = args[0].ndim - ndim - 1
assert all(a < mx for a in axes)
assert all(a == axes[0] for a in axes[1:])
return prim.bind(*args, **kwargs), axes[0]
num_repeats = source.shape[bsource]
source = batching.moveaxis(source, bsource, 0)

mapped_points = []
for x, bx in zip(points, bpoints):
if bx is batching.not_mapped:
mapped_points.append(jnp.repeat(x[jnp.newaxis], num_repeats, axis=0))
else:
mapped_points.append(batching.moveaxis(x, bx, 0))
return prim.bind(source, *mapped_points, **kwargs), 0


def pad_shapes(output_dim, source, *points):
Expand Down
78 changes: 56 additions & 22 deletions tests/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,24 +145,41 @@ def test_nufft1_vmap(ndim, num_nonnuniform, num_uniform, iflag):
dtype = np.double
cdtype = np.cdouble

num_repeat = 5
num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim))

x = [
random.uniform(-np.pi, np.pi, size=num_nonnuniform).astype(dtype)
random.uniform(-np.pi, np.pi, size=(num_repeat, num_nonnuniform)).astype(dtype)
for _ in range(ndim)
]
c = random.normal(size=num_nonnuniform) + 1j * random.normal(size=num_nonnuniform)
c = random.normal(size=(num_repeat, num_nonnuniform)) + 1j * random.normal(
size=(num_repeat, num_nonnuniform)
)
c = c.astype(cdtype)

num = 5
xs = [jnp.repeat(x_[None], num, axis=0) for x_ in x]
cs = jnp.repeat(c[None], num, axis=0)

func = partial(nufft1, num_uniform, iflag=iflag)
calc = jax.vmap(func)(cs, *xs)
expect = func(c, *x)
for n in range(num):
np.testing.assert_allclose(calc[n], expect)

# Start by checking the full basic vmap
calc = jax.vmap(func)(c, *x)
for n in range(num_repeat):
np.testing.assert_allclose(calc[n], func(c[n], *(x_[n] for x_ in x)))

# With different in_axes
calc_ax = jax.vmap(func, in_axes=(1,) + (0,) * ndim)(jnp.moveaxis(c, 0, 1), *x)
np.testing.assert_allclose(calc_ax, calc)

# With unmapped source axis
calc_unmap = jax.vmap(func, in_axes=(None,) + (0,) * ndim)(c[0], *x)
for n in range(num_repeat):
np.testing.assert_allclose(calc_unmap[n], func(c[0], *(x_[n] for x_ in x)))

# With unmapped points axis
calc_unmap_pt = jax.vmap(func, in_axes=(0,) + (0,) * (ndim - 1) + (None,))(
c, *x[:-1], x[-1][0]
)
for n in range(num_repeat):
np.testing.assert_allclose(
calc_unmap_pt[n], func(c[n], *(x_[n] for x_ in x[:-1]), x[-1][0])
)


@pytest.mark.parametrize(
Expand All @@ -175,24 +192,41 @@ def test_nufft2_vmap(ndim, num_nonnuniform, num_uniform, iflag):
dtype = np.double
cdtype = np.cdouble

num_repeat = 5
num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim))

x = [
random.uniform(-np.pi, np.pi, size=num_nonnuniform).astype(dtype)
random.uniform(-np.pi, np.pi, size=(num_repeat, num_nonnuniform)).astype(dtype)
for _ in range(ndim)
]
f = random.normal(size=num_uniform) + 1j * random.normal(size=num_uniform)
f = random.normal(size=(num_repeat,) + num_uniform) + 1j * random.normal(
size=(num_repeat,) + num_uniform
)
f = f.astype(cdtype)

num = 5
xs = [jnp.repeat(x_[jnp.newaxis], num, axis=0) for x_ in x]
fs = jnp.repeat(f[jnp.newaxis], num, axis=0)

func = partial(nufft2, iflag=iflag)
calc = jax.vmap(func)(fs, *xs)
expect = func(f, *x)
for n in range(num):
np.testing.assert_allclose(calc[n], expect)

# Start by checking the full basic vmap
calc = jax.vmap(func)(f, *x)
for n in range(num_repeat):
np.testing.assert_allclose(calc[n], func(f[n], *(x_[n] for x_ in x)))

# With different in_axes
calc_ax = jax.vmap(func, in_axes=(1,) + (0,) * ndim)(jnp.moveaxis(f, 0, 1), *x)
np.testing.assert_allclose(calc_ax, calc)

# With unmapped source axis
calc_unmap = jax.vmap(func, in_axes=(None,) + (0,) * ndim)(f[0], *x)
for n in range(num_repeat):
np.testing.assert_allclose(calc_unmap[n], func(f[0], *(x_[n] for x_ in x)))

# With unmapped points axis
calc_unmap_pt = jax.vmap(func, in_axes=(0,) + (0,) * (ndim - 1) + (None,))(
f, *x[:-1], x[-1][0]
)
for n in range(num_repeat):
np.testing.assert_allclose(
calc_unmap_pt[n], func(f[n], *(x_[n] for x_ in x[:-1]), x[-1][0])
)


def test_multi_transform():
Expand Down

0 comments on commit 5c9445d

Please sign in to comment.