
Prevent custom calls with side effects to be optimized out #3829

Closed
dionhaefner opened this issue Jul 23, 2020 · 22 comments

Labels
enhancement New feature or request

Comments

@dionhaefner

I am currently experimenting with implementing MPI send / recv as custom XLA calls.

It works fine in most cases, but a function like this leads to a deadlock:

@jax.jit
def send_recv(x):
    if rank == 0:
        x = Recv(x, comm=comm)
    else:
        Send(x, 0, comm=comm)
        # works if doing x = Send(x, 0, comm=comm)
    return x

I guess this is because the return value of Send is not used in the computational graph, so the whole call is optimized away, despite having side effects.

Is there a way to prevent this?

@mattjj
Collaborator

mattjj commented Jul 23, 2020

I think the bigger issue is that jit can't stage out MPI calls, not only for tracing reasons but more importantly because those calls don't exist in XLA.

The clearest path forward would be to jit smaller functions, and leave the MPI calls outside the jitted functions so they execute in Python. Is that an option?
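
For concreteness, the structure I have in mind is something like the sketch below (assuming two MPI ranks; compute is a made-up kernel, and the communication uses plain mpi4py calls rather than custom XLA ops):

import jax
import jax.numpy as jnp
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

@jax.jit
def compute(x):
  return jnp.tanh(x) + 1.0  # stand-in for the real numerical work

def step(x):
  y = jax.device_get(compute(x))  # jitted numerics, result pulled to the host
  if rank == 0:
    y = comm.recv(source=1)  # MPI calls run in ordinary Python,
  else:                      # outside the traced computation
    comm.send(y, dest=0)
  return y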

@PhilipVinc
Contributor

PhilipVinc commented Jul 23, 2020

@mattjj Nope, that's not a problem.

I managed to write a small package defining XLA custom ops in Cython for MPI operations.
At the moment the only supported op is Allreduce, which works very well.

We'd also like to implement other operations like send/recv, as @dionhaefner pointed out above.

@PhilipVinc
Contributor

(The rationale behind this is that in some complicated distributed batched MatMul code, I see remarkable speed-ups with this approach compared to jitting only parts of the function. Also, if you want to apply AD or run those functions through a linear solver like JAX's CG, you have no choice but to jit the whole function.)

@hawkinsp
Collaborator

It actually turns out that XLA already supports this, but we haven't plumbed it through as part of our Python bindings. It should be an easy fix, but it will require a jaxlib rebuild.

@PhilipVinc
Contributor

Is this something you'd consider doing?
If not, could you point us in the right direction? I cannot find any mention of this in XLA's documentation.

@hawkinsp
Collaborator

I'm preparing a change to add this now.

@mattjj
Collaborator

mattjj commented Jul 23, 2020

Wow, awesome work on mpi4jax! Thanks for clarifying; I missed that point in @dionhaefner's original message, though I see it now.

@hawkinsp in addition to the XLA update, we still have a tracing issue to fix, right? That is, we basically need #3370 to land, or else some other way to ensure the Send is actually staged out from JAX.

@hawkinsp
Collaborator

We don't necessarily need #3370; if nothing else we can use XLA tokens here.

@mattjj
Collaborator

mattjj commented Jul 23, 2020

Right, we need to fix the tracing issue, which could be #3370 or could be adding tokens to the code. But to support the code as written in the OP, i.e. without tokens, we need #3370.

tensorflow-copybara pushed a commit to tensorflow/tensorflow that referenced this issue Jul 23, 2020
Will fix jax-ml/jax#3829 when incorporated into a jaxlib release.

PiperOrigin-RevId: 322787456
Change-Id: If2ade6a15875c476d0e160b6ef17a4fb0b2d37fe
@PhilipVinc
Contributor

PhilipVinc commented Jul 23, 2020

I'm sorry, but I'm not sure I follow.
What are tokens? What is the issue you are referring to?

Because for Allreduce (an operation that always has an output) our approach already works.
The issue here is that operations whose result is not used locally, such as send, can be optimised out of the IR.

mattjj added the enhancement label Jul 23, 2020
@mattjj
Collaborator

mattjj commented Jul 23, 2020

> What are tokens? What is the issue you are referring to?
> The issue here is that operations whose result is not used locally, such as send, can be optimised out of the IR.

There are essentially two places where dead code elimination (DCE) may happen, leading to pruning of operations that (as far as the system currently understands) have no effect on the result of the computation. One is the Python tracing mechanism, which has nothing to do with XLA or jaxlib:

In [1]: def f(x):
   ...:     y = x + x
   ...:     return x + 1
   ...:

In [2]: from jax import make_jaxpr

In [3]: make_jaxpr(f)(2)
Out[3]:
{ lambda  ; a.
  let b = add a 1
  in (b,) }

The other is XLA, which will prune operations unless (a) the value of the computation has a data dependence on the operation, or (b) the operation is marked as side-effectful.

I believe the fix @hawkinsp alluded to is about the latter, essentially adding a way to label XLA CustomCalls as side-effecting so that XLA doesn't prune them.

My point above is that we're still left with the former, i.e. the JAX tracing issue. The way to solve that on the current master branch is to add 'token' values which we thread into and out of side-effecting operations and on which the final result has a fake data dependence; there are more details to unpack, but you can see an example in the infeed tests. On the #3370 branch we don't need tokens on the JAX side anymore, as you can see in that branch's version of infeed_test.py. This all has to do with the JAX Python side; in any case we needed the fix Peter added in XLA.
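
For reference, the token pattern in those tests looks roughly like this (a condensed sketch, not verbatim from infeed_test.py; the shape is illustrative):

import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def f(x):
  token = lax.create_token(x)  # fresh token derived from x
  # infeed consumes and returns the token; since the result depends on y,
  # neither JAX tracing nor XLA can prune the infeed away
  (y,), token = lax.infeed(token, shape=(jax.ShapedArray((3,), jnp.float32),))
  return x + y

# calling f blocks until the host feeds data, e.g. via
# jax.devices()[0].transfer_to_infeed((numpy.ones(3, numpy.float32),))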

Does that make sense?

hawkinsp mentioned this issue Jul 23, 2020
@hawkinsp
Collaborator

#3834 adds a new has_side_effects argument to CustomCall, which handles the XLA side of this.
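
For illustration, a translation rule would pass the flag roughly like this (a sketch only: mpi_send is a placeholder target name, and the exact spelling of the argument may differ between jaxlib versions):

from jaxlib import xla_client

def send_translation(builder, token, x):
  # emit the custom call with the side-effect flag so XLA won't DCE it
  return xla_client.ops.CustomCall(
      builder,
      b'mpi_send',  # registered custom-call target (placeholder)
      operands=(token, x),
      shape=xla_client.Shape.token_shape(),  # the op returns only a token
      has_side_effect=True,
  )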

You can also work around both the XLA-side and JAX-side issues by adding a dummy output to your operator (e.g., make it return a dummy scalar that you then return).
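
For example, with a hypothetical send_p primitive (only an abstract eval is defined here, which is enough to see the tracing behaviour):

import jax
import jax.numpy as jnp
from jax.core import Primitive

send_p = Primitive('send')
# pretend the operator returns a dummy int32 scalar
send_p.def_abstract_eval(lambda x: jax.ShapedArray((), jnp.int32))

def Send(x):
  return send_p.bind(x)

def f(x):
  dummy = Send(x)
  # returning the dummy gives the output a data dependence on the send,
  # so neither JAX tracing nor XLA will prune it
  return x, dummy

print(jax.make_jaxpr(f)(jnp.zeros(3)))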

@mattjj
Collaborator

mattjj commented Jul 23, 2020

By the way, tokens are also useful for sequencing things. That is, JAX jit tracing may reorder operations when there is no data dependence between them (until #3370, which preserves Python execution order), and similarly XLA may reorder operations with no data dependence (one must use tokens for that AIUI).
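
A quick sketch of the sequencing point, with a hypothetical effect_p primitive standing in for a side-effecting custom call (only its abstract eval is defined, which suffices for tracing):

import jax
import jax.numpy as jnp
from jax import lax
from jax.core import Primitive, abstract_token

effect_p = Primitive('effect')
effect_p.def_abstract_eval(lambda token, x: abstract_token)

def f(x):
  token = lax.create_token(x)
  token = effect_p.bind(token, x * 1.0)  # first effect
  # threading the same token in and out orders the second effect after the
  # first, even though neither touches the other's data
  token = effect_p.bind(token, x * 2.0)
  return lax.tie_in(token, x)

print(jax.make_jaxpr(f)(jnp.zeros(3)))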

@PhilipVinc
Contributor

PhilipVinc commented Jul 23, 2020 via email

@hawkinsp
Collaborator

We just made a jaxlib release yesterday :-( I suggest building from source for now?

@mattjj
Collaborator

mattjj commented Jul 23, 2020

> I think that’s all that is needed.
> We already define a custom JAX primitive, and when you call ‘Send’ it is bound to the input.

Actually no, jit tracing will drop primitives that are bound if the output of the jitted computation doesn't have a data dependence on the result of the primitive. That's what my example above with make_jaxpr was meant to show: even though we call two adds, we only build a jaxpr (and then an XLA computation) containing the one that affected the output.

Concretely, the code in the OP will still drop the Send call. It'll get dropped on the JAX side, before XLA even has a chance to see it.

@mattjj
Collaborator

mattjj commented Jul 23, 2020

Here's a more self-contained example if you prefer:

from jax.core import Primitive

def Send(x):
  return send_p.bind(x)  # result is never used below

send_p = Primitive('send')
send_p.def_abstract_eval(lambda x: None)


def f(x):
  Send(x)

from jax import make_jaxpr
print(make_jaxpr(f)(2))
{ lambda  ; a.
  let
  in () }

You can change the make_jaxpr to a jit and add a print(built.as_hlo_text()) after this line in xla.py if you want to convince yourself that XLA will never see the bound primitive.

@dionhaefner
Author

Thanks a bunch for the explanation and the quick fix! I will try this ASAP and report back.

@dionhaefner
Author

OK, so it seems like there is currently no way to preserve the Send call in nested JIT calls?

I.e.

def Send(x, dest, tag=0, comm=_MPI.COMM_WORLD):
    token = lax.create_token(x)               # token boilerplate
    token = _Send(token, x, dest, tag, comm)  # side-effecting primitive
    return lax.tie_in(token, x)               # fake data dependence keeps it alive

def Send_nested(x, dest, tag=0, comm=_MPI.COMM_WORLD):
    Send(x, dest, tag, comm)

print(jax.make_jaxpr(Send)(jnp.zeros(1), 0))
print(jax.make_jaxpr(Send_nested)(jnp.zeros(1), 0))

gives

{ lambda  ; a b.
  let c = create_token a
      d = send_mpi[ comm=<mpi4py.MPI.Intracomm object at 0x10928e5d0>
                    dest=Traced<ShapedArray(int32[], weak_type=True):JaxprTrace(level=0/0)>
                    tag=0 ] c a
      e = tie_in d a
  in (e,) }

{ lambda  ; a b.
  let
  in () }

So all user code that uses our Send implementation would have to include some boilerplate with create_token and tie_in.

@mattjj
Collaborator

mattjj commented Jul 24, 2020

Yes, that's right for now, but with #3370 neither tokens nor tie_in will be necessary.

@dionhaefner
Author

Looking forward to that! Do you have any estimate when omnistaging is going to hit master?

@dionhaefner
Author

This works beautifully with omnistaging, thanks!

I can also confirm that it does not work without setting has_side_effect=True, so both changes were necessary. Great job @mattjj and @hawkinsp!
