Prevent custom calls with side effects from being optimized out #3829
Comments
I think the bigger issue is that … The clearest path forward would be to …
@MattJ Nope, that's not a problem. I managed to write a small package defining XLA custom ops in Cython for MPI operations. We'd like to also implement other operations like send/recv, as @dionhaefner pointed out above.
(The rationale behind this is that in some complicated distributed batched MatMul code, I see remarkable speed-ups with this approach compared to jitting only part of the functions. Also, if you want to AD those functions or run them through a linear solver like JAX's CG, you have no choice but to jit the whole function.)
It actually turns out that XLA already supports this, but we haven't plumbed it through as part of our Python bindings. It should be an easy fix, but it will require a jaxlib update.
Is it something you'd consider doing?
I'm preparing a change to add this now.
Wow, awesome work on mpi4jax! Thanks for clarifying; I missed that point in @dionhaefner's original message, though I see it now. @hawkinsp in addition to the XLA update, we still have a tracing issue to fix, right? That is, we basically need #3370 to land, or else some other way to ensure the operation isn't pruned during tracing.
We don't necessarily need #3370; if nothing else we can use XLA tokens here.
Will fix jax-ml/jax#3829 when incorporated into a jaxlib release. PiperOrigin-RevId: 322787456 Change-Id: If2ade6a15875c476d0e160b6ef17a4fb0b2d37fe
I'm sorry, but I'm not sure I follow, because for allreduce (an operation that always has an output) our approach already works.
There are essentially two places where dead code elimination (DCE) may happen, leading to pruning of operations that (as far as the system currently understands) have no effect on the result of the computation. One is the Python tracing mechanism, which has nothing to do with XLA or jaxlib:
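For illustration, here is a minimal sketch of that tracing-side pruning (not the thread's original example; it assumes the pre-omnistaging JAX of this era, in which jaxprs were built from data dependences on the outputs):

```python
import jax
import jax.numpy as jnp

def f(x):
    jnp.sin(x)   # result unused: nothing depends on this value
    return x + 1

# On a pre-omnistaging JAX (before #3370), the printed jaxpr contains only
# the `add` equation; the unused `sin` never makes it into the trace.
print(jax.make_jaxpr(f)(2.0))
```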
The other is XLA, which will prune operations unless (a) the value of the computation has a data dependence on the operation, or (b) the operation is marked as side-effecting. I believe the fix @hawkinsp alluded to is about the latter, essentially adding a way to label XLA CustomCalls as side-effecting so that XLA doesn't prune them. My point above is that we're still left with the former, i.e. the JAX tracing issue. The way to solve that on the current master branch is to add 'token' values which we thread into and out of side-effecting operations and on which the final result has a fake data dependence; there are more details to unpack, but you can see an example in the …
Does that make sense?
#3834 adds a new … You can also work around both the XLA and JAX-side issues by adding a dummy output to your operator (e.g., make it return a scalar that you then return).
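A minimal sketch of that dummy-output workaround, assuming a primitive defined with `jax.core.Primitive` as elsewhere in this thread (the name `send_p` is an assumption, and the eager implementation and XLA lowering are omitted):

```python
import numpy as np
from jax.core import Primitive, ShapedArray

send_p = Primitive('send')
# Give the primitive a real (dummy) scalar output, so that both JAX tracing
# and XLA see a value the caller can depend on.
send_p.def_abstract_eval(lambda x: ShapedArray((), np.float32))

def Send(x):
    # The caller must keep a data dependence on this dummy value
    # (e.g. return it), or the call can still be pruned.
    return send_p.bind(x)
```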
By the way, tokens are also useful for sequencing things. That is, JAX jit tracing may reorder operations when there is no data dependence between them (until #3370, which preserves Python execution order), and similarly XLA may reorder operations with no data dependence (one must use tokens for that AIUI).
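For instance, a small sketch of the token-threading pattern (hypothetical code: `_Send` below is a pass-through stub standing in for the real side-effecting MPI primitive so the snippet is self-contained, and `lax.tie_in` is the era-appropriate API, which became a no-op after omnistaging):

```python
import jax.numpy as jnp
from jax import lax, make_jaxpr

def _Send(token, x):
    # Stand-in for the real side-effecting primitive; it just passes the
    # token through so the sketch runs.
    return token

def two_sends(x, y):
    token = lax.create_token(x)
    token = _Send(token, x)  # first send
    token = _Send(token, y)  # threading the token sequences this after the first
    # Fake data dependence on the token chain so it isn't pruned.
    return lax.tie_in(token, x)

print(make_jaxpr(two_sends)(jnp.zeros(1), jnp.ones(1)))
```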
I think that’s all that is needed.
We already define a custom JAX primitive, and when you call ‘Send’ it is bound to the input. The only problem was that this primitive has no output (or we don’t use it), so XLA was removing it...
Thanks a lot!
Any chance we can get a new release of jaxlib somewhat soon?
We just made one yesterday :-( I suggest building from source for now?
Actually no, … Concretely, the code in the OP will still drop the `Send`.
Here's a more self-contained example if you prefer:

```python
from jax.core import Primitive
from jax import make_jaxpr

def Send(x):
    return send_p.bind(x)

send_p = Primitive('send')
send_p.def_abstract_eval(lambda x: None)

def f(x):
    Send(x)

print(make_jaxpr(f)(2))
```
You can change the …
Thanks a bunch for the explanation and the quick fix! I will try this ASAP and report back.
OK, so it seems like there is currently no way to preserve the Send call in nested JIT calls? I.e.

```python
def Send(x, dest, tag=0, comm=_MPI.COMM_WORLD):
    token = lax.create_token(x)
    token = _Send(token, x, dest, tag, comm)
    return lax.tie_in(token, x)

def Send_nested(x, dest, tag=0, comm=_MPI.COMM_WORLD):
    Send(x, dest, tag, comm)

print(jax.make_jaxpr(Send)(jnp.zeros(1), 0))
print(jax.make_jaxpr(Send_nested)(jnp.zeros(1), 0))
```

gives …
So all user code that uses our `Send` has to make use of its return value?
Yes, that's right for now, but with #3370 neither tokens nor `tie_in` will be necessary.
Looking forward to that! Do you have any estimate of when omnistaging is going to hit master?
I am currently experimenting with implementing MPI send / recv as custom XLA calls.
It works fine in most cases, but a function like this leads to a deadlock:
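A hypothetical reconstruction of the kind of function described (the original snippet is not shown here; `Send`, `Recv`, `dest`, `source`, and the static MPI `rank` are assumptions based on the wrappers discussed in this thread):

```python
import jax

# Send/Recv are the custom-call MPI wrappers from this thread; rank is this
# process's static MPI rank, resolved at trace time.
@jax.jit
def exchange(x):
    if rank == 0:
        Send(x, dest=1)       # return value unused -> the call gets pruned
        return x
    return Recv(x, source=0)  # the matching send never runs: deadlock
```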
I guess this is because the return value of `Send` is not used in the computational graph, so the whole call is optimized away, despite having side effects. Is there a way to prevent this?