
Replacing deprecated XLA translation rules with MLIR lowering rules #21

Merged: 3 commits merged on Oct 31, 2023
Conversation

@dfm (Collaborator) commented Oct 30, 2023

The old interface used to define XLA translation rules has been deprecated in favor of MLIR lowering rules. This PR replaces the old implementation.

The pre-commit rules were also slightly updated and run on the full code base, so this is a slightly larger diff than I was planning.
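For context, the migration this PR performs looks roughly like the following toy example. This is an illustrative sketch, not jax-finufft's actual code: the `double_p` primitive is made up, but `mlir.register_lowering` and `mlir.lower_fun` are the JAX APIs that replace the deprecated translation-rule tables.

```python
# Illustrative sketch (not jax-finufft's code): a toy primitive whose
# lowering is registered through the new MLIR interface, the mechanism
# this PR migrates to.
import numpy as np
import jax
from jax.interpreters import mlir

try:
    from jax.extend.core import Primitive  # newer JAX releases
except ImportError:
    from jax.core import Primitive

# A toy primitive that doubles its input.
double_p = Primitive("double")
double_p.def_impl(lambda x: 2 * x)
double_p.def_abstract_eval(lambda x: x)  # output has the input's shape/dtype

# Old path: an entry in XLA's backend-specific translation-rule tables.
# New path: register a per-platform MLIR lowering. mlir.lower_fun turns
# a JAX-traceable function into a lowering rule, so no StableHLO ops
# need to be emitted by hand for this sketch.
mlir.register_lowering(
    double_p,
    mlir.lower_fun(lambda x: 2 * x, multiple_results=False),
    platform="cpu",
)

def double(x):
    return double_p.bind(x)

print(jax.jit(double)(np.arange(3.0)))  # [0. 2. 4.]
```

The `platform=` keyword is what makes the registration backend-specific, which is also where the GPU bug discussed below crept in.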

@dfm (Collaborator, Author) commented Oct 30, 2023

@lgarrison — I don't have a mechanism for testing the GPU ops. Would you be able to run a quick check?

@lgarrison (Member):

No problem! Unfortunately it looks like something did break on the GPU side. Here's the first failing test:

(venv-cu11) scclin021:~/jax-finufft$ python -m pytest -xv tests/
============================================================== test session starts ===============================================================
platform linux -- Python 3.11.2, pytest-7.4.2, pluggy-1.3.0 -- /mnt/home/lgarrison/jax-finufft/venv-cu11/bin/python
cachedir: .pytest_cache
rootdir: /mnt/home/lgarrison/jax-finufft
collected 53 items                                                                                                                               

tests/ops_test.py::test_nufft1_forward[1-False-50-75--1] SKIPPED (1D transforms not implemented on GPU)                                    [  1%]
tests/ops_test.py::test_nufft1_forward[1-False-50-75-1] SKIPPED (1D transforms not implemented on GPU)                                     [  3%]
tests/ops_test.py::test_nufft1_forward[1-True-50-75--1] SKIPPED (1D transforms not implemented on GPU)                                     [  5%]
tests/ops_test.py::test_nufft1_forward[1-True-50-75-1] SKIPPED (1D transforms not implemented on GPU)                                      [  7%]
tests/ops_test.py::test_nufft1_forward[2-False-50-75--1] FAILED                                                                            [  9%]

==================================================================== FAILURES ====================================================================
_____________________________________________________ test_nufft1_forward[2-False-50-75--1] ______________________________________________________

ndim = 2, x64 = False, num_nonnuniform = 50, num_uniform = (37, 42), iflag = -1

    @pytest.mark.parametrize(
        "ndim, x64, num_nonnuniform, num_uniform, iflag",
        product([1, 2, 3], [False, True], [50], [75], [-1, 1]),
    )
    def test_nufft1_forward(ndim, x64, num_nonnuniform, num_uniform, iflag):
        if ndim == 1 and jax.default_backend() != "cpu":
            pytest.skip("1D transforms not implemented on GPU")
    
        random = np.random.default_rng(657)
    
        eps = 1e-10 if x64 else 1e-7
        dtype = np.double if x64 else np.single
        cdtype = np.cdouble if x64 else np.csingle
    
        num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim))
        ks = [np.arange(-np.floor(n / 2), np.floor((n - 1) / 2 + 1)) for n in num_uniform]
    
        x = random.uniform(-np.pi, np.pi, size=(ndim, num_nonnuniform)).astype(dtype)
        c = random.normal(size=num_nonnuniform) + 1j * random.normal(size=num_nonnuniform)
        c = c.astype(cdtype)
        f_expect = np.zeros(num_uniform, dtype=cdtype)
        for coords in product(*map(range, num_uniform)):
            k_vec = np.array([k[n] for (n, k) in zip(coords, ks)])
            f_expect[coords] = np.sum(c * np.exp(iflag * 1j * np.dot(k_vec, x)))
    
        with jax.experimental.enable_x64(x64):
            f_calc = nufft1(num_uniform, c, *x, eps=eps, iflag=iflag)
>           np.testing.assert_allclose(f_calc, f_expect, rtol=5e-7 if x64 else 5e-2)

tests/ops_test.py:41: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<function assert_allclose.<locals>.compare at 0x7f636eb7b7e0>, array([[nan+nanj, nan+nanj, nan+nanj, ..., nan+nanj, n...-3.0006557 -5.3153861e-01j,
        -10.215852  +8.2104796e-01j,  -2.1991322 -9.8890707e-02j]],
      dtype=complex64))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=0.05, atol=0', 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Not equal to tolerance rtol=0.05, atol=0
E           
E           x and y nan location mismatch:
E            x: array([[nan+nanj, nan+nanj, nan+nanj, ..., nan+nanj, nan+nanj, nan+nanj],
E                  [nan+nanj, nan+nanj, nan+nanj, ..., nan+nanj, nan+nanj, nan+nanj],
E                  [nan+nanj, nan+nanj, nan+nanj, ..., nan+nanj, nan+nanj, nan+nanj],...
E            y: array([[ -0.873955-3.174950e+00j,  -0.80325 +4.591350e-03j,
E                    -3.518164+5.852695e+00j, ...,  -9.302967-3.606844e-01j,
E                    -4.286073+5.104557e+00j,   2.473569-6.756186e+00j],...

/mnt/sw/nix/store/wignb2nj7xdjs8y6307gcb18jr5lcqpm-python-3.11.2-view/lib/python3.11/contextlib.py:81: AssertionError
-------------------------------------------------------------- Captured stderr call --------------------------------------------------------------
setup_spreader: warning, increasing tol=1e-07 to eps_mach=1.19e-07.
============================================================ short test summary info =============================================================
FAILED tests/ops_test.py::test_nufft1_forward[2-False-50-75--1] - AssertionError: 
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
========================================================== 1 failed, 4 skipped in 2.43s ==========================================================

None of the GPU tests pass; all of the CPU ones do. Something nasty must be happening because the GPU tests eventually crash with what I think is a bad address on the GPU side.

@dfm mentioned this pull request on Oct 30, 2023
@dfm (Collaborator, Author) commented Oct 30, 2023

OK - thanks! I will look into what's happening here.

@@ -194,7 +194,7 @@ def batch(args, axes, *, output_shape, **kwargs):
 nufft2_p.def_abstract_eval(shapes.abstract_eval)
 mlir.register_lowering(nufft2_p, partial(lowering.lowering, "cpu"), platform="cpu")
 if lowering.jax_finufft_gpu is not None:
-    mlir.register_lowering(nufft2_p, partial(lowering.lowering, "cpu"), platform="gpu")
+    mlir.register_lowering(nufft2_p, partial(lowering.lowering, "gpu"), platform="gpu")
@dfm (Collaborator, Author) commented on the diff:
@lgarrison — Doh! I think I found the issue lol

@lgarrison (Member):

Looks like that was it, the GPU tests pass now!

@dfm merged commit ec080e4 into main on Oct 31, 2023 (3 checks passed).
@dfm deleted the mlir branch on October 31, 2023 at 01:19.