-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Send / recv proof of concept #2
Conversation
One solution could be to only implement @jax.jit
def foo(x):
y = sendrecv(x, source=0, dest=1)
if rank == 1:
return y
return x since rank 0 doesn't do anything with |
I think that this is because XLA's compiler is very aggressive. I think for this we should ask jag's people if it's somehow possible to tag as 'do_not_optimise' a function. |
By the way, if you rebase, tests should be working now. |
I guess that until Jax#3370 is merged we should rather focus on collective all-to-all communications, which are not affected by the side-effect problem. |
Collective all-to-all operations are affected, too, though. Example: @jax.jit
def foo(x):
x = Allreduce(x)
if rank == 0:
return 0 # kaboom
return x So right now, it is the user’s responsibility to make sure that there is a data dependency on the return value of the MPI calls. |
Allreduce is what I have implemented and, at least in my experience, is working well. |
Ah, ok, I get it. What I meant is that all-to-all are (usually) used in contexts where all ranks execute the same code, so that is (sometimes) not an issue. |
I agree, all-to-all are lower risk. Ultimately it's the user's responsibility not to mess up though, so we should put a warning in the readme or so :) |
The omnistaging and |
Yes, I think that would be sensible. I don't think there's a strong motivation to introduce a bunch of extra logic for JAX versions pre-omnistaging / side effect support. |
What is the tradeoff of omnistaging? |
It works fine if you don't use But yes, when / if omnistaging becomes the default in JAX I think we should just require it. |
This is done from my side. Tests are failing hard until omnistaging is released, but it works on my machine™️ |
When using this I noticed that XLA would sometimes re-order send and recv calls, which causes deadlocks. So unless a solution comes up in jax-ml/jax#3976, this will need some token mechanism to ensure proper order. This principally affects all primitives, but it's easiest to run into with |
Are you testing this with a JaxLib you built yourself, and their staging mechanism?
XLA should hopefully not reorder operations with side-effects, because... they have side effects.
Maybe we should cc them (can’t from my phone...)
Il 7 ago 2020, 12:04 +0200, Dion Häfner <notifications@github.com>, ha scritto:
… When using this I noticed that XLA would sometimes re-order send and recv calls, which causes deadlocks.
So unless a solution comes up in jax-ml/jax#3976, this will need some token mechanism to ensure proper order.
This principally affects all primitives, but it's easiest to run into with send and recv.
—
You are receiving this because you commented.
Reply to this email directly, view it on GitHub, or unsubscribe.
|
This is current JAX master. I've opened an issue already and the JAX devs confirmed the problem (jax-ml/jax#3976). If there's no fix on the horizon, I had some ideas of a "token tape" that should allow us to get around this with 1 line of extra boilerplate in user code. |
I added a token mechanism which ensures that proper order is conserved. The idea is to chain the calls by using an XLA token: token = Send(...) # not passing a token creates a new one
token = Send(..., token=token) # re-use previous token
arr, token = Recv(..., token=token)
arr, token = Sendrecv(..., token=token)
arr, token = Allreduce(..., token=token) As long as the correct token is passed, those statements should never get re-ordered (relative to each other) or optimized away. It sounds like the JAX people are cooking up a solution that does this token chaining automatically behind the scenes, but my feeling is that this might take a while to land. Left to-do:
|
very well |
How important are |
They are not essential. --
|
I'll patch out |
Ok, thanks for the investigation! That's fine by me. |
I'm back from holidays and starting to work again! How are we with the merging of omnistaging in jax? do you have any news? |
Welcome back! There has been a jaxlib release today, so this should work now :) |
Ah, it's just a tag, not a release... I asked for one. |
Yay! Thanks for bumping the google guys |
Done from my side (for real this time). |
All is great. |
The following script works:
There is one problem though that you have to assign something to the result of the
Send
call, otherwise it gets optimized out and everything deadlocks. I.e., this doesn't work:Not sure if there's anything we can do about that. The whole implementation is pretty hacky with an unnecessary
memcpy
, but I don't think JAX / XLA accounts for custom calls that have side effects.I didn't touch the gradient code (yet).