Does updating inside of jitted batch data parallelism work? #24882
Unanswered
logan-dunbar asked this question in Q&A
Under the "8-way batch data parallelism" example, the data is sharded along the batch dimension while the model parameters are replicated across all devices.
The params update happens outside of any jitted function, and I'm assuming the sharding takes care of replicating the changes to all devices.
Now suppose the update to params is done inside a jitted function instead, like in the sketch below. Does the update happen correctly, or do the params get overwritten in some unpredictable way, since essentially 8 GPUs are each trying to update the replicated data with their chunk of the computation?
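Something along these lines (a minimal sketch, not my exact code; `loss_fn`, the shapes, and the plain SGD rule are just placeholders):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 8-way batch data parallelism: data sharded on the leading (batch) axis,
# params fully replicated across the mesh.
mesh = Mesh(jax.devices(), axis_names=('batch',))
data_sharding = NamedSharding(mesh, P('batch'))
replicated_sharding = NamedSharding(mesh, P())

def loss_fn(params, x, y):
    pred = x @ params['w'] + params['b']
    return jnp.mean((pred - y) ** 2)

@jax.jit
def update(params, x, y):
    lr = 1e-3
    grads = jax.grad(loss_fn)(params, x, y)
    # The params update now happens *inside* the jitted function.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Placeholder shapes: batch of 64, feature dim 16.
key = jax.random.PRNGKey(0)
x = jax.device_put(jax.random.normal(key, (64, 16)), data_sharding)
y = jax.device_put(jax.random.normal(key, (64, 1)), data_sharding)
params = jax.device_put({'w': jnp.zeros((16, 1)), 'b': jnp.zeros((1,))},
                        replicated_sharding)

params = update(params, x, y)  # is `params` still correctly replicated after this?
```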
Replies: 1 comment
-
So I just ran the code in a Colab and it produces the same loss at the end of training, so can I assume it is working correctly? Does JAX know that, when updating the replicate-sharded params object, it needs to wait for all of the gradfun() returns and accumulate them into the replicated object? I'm really surprised that worked, to be honest; I thought for sure it would complain about race conditions on the update, or produce bogus values by overwriting. If it really works as advertised, then I'm stunned by how easy it has been to parallelize my code across multiple GPUs, and I'm even more stoked that I chose to invest in JAX :)
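One way to double-check (a hypothetical sketch reusing the `update`, `params`, `x`, and `y` names from the sketch in the question above): confirm that the jitted update leaves params fully replicated, and that it matches an update computed without any sharding.

```python
new_params = update(params, x, y)

# The updated params should still be fully replicated across the mesh.
print(new_params['w'].sharding.is_fully_replicated)   # expect True
jax.debug.visualize_array_sharding(new_params['w'])

# The numbers should match an update computed from unsharded (host) copies.
ref_params = update(jax.device_get(params),
                    jax.device_get(x),
                    jax.device_get(y))
print(jax.tree_util.tree_all(
    jax.tree_util.tree_map(lambda a, b: jnp.allclose(a, b),
                           new_params, ref_params)))   # expect True
```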