soft_pmap fails with unmapped args. #3400
Comments
@mattjj return pmap(fun, axis_name, backend=backend)(*args) to return pmap(fun, axis_name, in_axes=in_axes, backend=backend)(*args) and reshaped_args = [_reshape_split(num_chunks, x) for x in args_flat] to reshaped_args = [_reshape_split(num_chunks, x) if mapped else x for x, mapped in zip(args_flat, mapped_invars)] then |
Note, a temporary fix is of course to just use |
How did you find out about Just kidding :) It's a bit of an ugly prototype at the moment. #3370 rewrites it, though only as a temporary patch before we make something better. Since #3370 changes it pretty significantly (for example, |
@mattjj Oh, I dig deep when I learn a new code base :) I like the principle of being able to simulate XLA devices for two reasons: the ability to map more elements than there are available devices, and to test distributed programs on a single machine before deploying. It seems there are two ways to do this: a method like soft_pmap, or a many-to-one mapping of devices. Perhaps I like the second idea more, but @skye mentioned that it would be significantly harder. Excited to try the new code when it's out. My use case is radio astronomy calibration and imaging. Typical datasets are ~10TB and we need to keep the code as close to the data as possible, so we try to avoid multi-machine computing. The machines typically have 512GB of memory and 64 cores, and the idea is to keep memory as free as possible and use the CPUs as much as possible while avoiding IO blocking. I'm going to try using JAX end-to-end on a new type of algorithm. Still assessing if that's possible. Actually, it would be great to Skype and get some answers directly! It would probably save me weeks of probing. Would you be available? Also, regarding your offer, I'll wait for the upcoming release and use the workaround for now.
Consider both paths of soft_pmap.

First, when chunksize = 0. Following the call, we notice that inside soft_pmap, when the mapped axis size is smaller than the device count, it simply calls pmap without the in_axes argument, so in_axes gets the default value of 0. This then results in an error.

Then consider when chunksize > 0 (set xla_force_host_platform_device_count=5 in above) and evenly divides the leading dimension. Inside soft_pmap, it tries to reshape all arguments without heeding mapped_invars (which is correctly (True, False)).

Simple solution: on the first branch, pass along the in_axes. On the second branch, only reshape when the corresponding value in mapped_invars is True.
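
The second-branch fix can be illustrated with a minimal sketch. It uses NumPy arrays in place of JAX arrays so that it runs without any device setup. The helpers reshape_split and split_mapped_args are hypothetical stand-ins written for this example, not the actual JAX internals; the reshape rule (splitting the leading axis into num_chunks by chunk size) is an assumption about what the internal helper does.

```python
import numpy as np

def reshape_split(num_chunks, x):
    # Hypothetical stand-in: split the leading axis into (num_chunks, chunk_size).
    return x.reshape((num_chunks, x.shape[0] // num_chunks) + x.shape[1:])

def split_mapped_args(num_chunks, args_flat, mapped_invars):
    # Reshape only the arguments flagged as mapped; pass the others through unchanged.
    return [
        reshape_split(num_chunks, x) if mapped else x
        for x, mapped in zip(args_flat, mapped_invars)
    ]

x = np.arange(10.0)  # mapped over its leading axis (length 10)
y = np.ones(3)       # unmapped, shared by every chunk
x_split, y_same = split_mapped_args(5, [x, y], (True, False))
print(x_split.shape, y_same.shape)  # (5, 2) (3,)
```

Reshaping y unconditionally would fail here, since its leading dimension of 3 cannot be split into 5 chunks. That is the same kind of failure the second branch runs into when it ignores mapped_invars.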