
soft_pmap fails with unmapped args. #3400

Closed
Joshuaalbert opened this issue Jun 11, 2020 · 5 comments
Labels: enhancement (New feature or request)

Joshuaalbert (Contributor) commented Jun 11, 2020

Consider both code paths of soft_pmap. First, the path taken when chunksize = 0:

import os
os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=6"  # must be set before importing jax
import jax
import jax.numpy as jnp

def test_soft_pmap():
    def fun(x, z):
        return (x - z, x + z)
    x = jnp.ones((5, 5))    # mapped over axis 0
    z = jnp.array([1.])     # unmapped
    pfun = jax.soft_pmap(fun, in_axes=(0, None))
    pfun(x, z)

Following the call, we notice that when the mapped axis size is smaller than the device count, soft_pmap simply calls pmap without forwarding the in_axes argument, so in_axes gets its default value of 0. This then results in an error.
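
To see why, the equivalent direct pmap call fails the same way (a minimal sketch, reusing fun, x, and z from the snippet above):

jax.pmap(fun)(x, z)  # in_axes defaults to 0 for every argument
# fails: axis 0 sizes disagree (x has size 5, z has size 1)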

Then consider the case when chunksize > 0 (set xla_force_host_platform_device_count=5 in the example above) and it evenly divides the leading dimension. Inside soft_pmap, every argument is reshaped without heeding mapped_invars (which is correctly (True, False)), so the unmapped z gets reshaped as well.
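
Conceptually, the reshape step does something like the following (an illustrative sketch of what soft_pmap's internal _reshape_split does, not the exact code):

def reshape_split(num_chunks, x):
    # (num_chunks * chunk, ...) -> (num_chunks, chunk, ...)
    return x.reshape((num_chunks, x.shape[0] // num_chunks) + x.shape[1:])

reshape_split(5, jnp.ones((5, 5))).shape  # (5, 1, 5): fine for the mapped x
reshape_split(5, jnp.array([1.]))         # fails: z's leading axis has size 1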

Simple solution: on the first branch, pass along in_axes; on the second branch, only reshape an argument when the corresponding value in mapped_invars is True.

Joshuaalbert (Contributor, Author) commented:

@mattjj
If you make the following replacements in soft_pmap:

return pmap(fun, axis_name, backend=backend)(*args)

to

return pmap(fun, axis_name, in_axes=in_axes, backend=backend)(*args)

and

reshaped_args = [_reshape_split(num_chunks, x) for x in args_flat]

to

reshaped_args = [_reshape_split(num_chunks, x) if mapped else x for x, mapped in zip(args_flat, mapped_invars)]

then soft_pmap with unmapped args works for mapped_axis_size <= device_count, but not when mapped_axis_size = L*device_count for some natural number L > 1. In that case, it looks like something in pxla.split_axis needs to be modified, I think, though I can't be sure.
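
For concreteness, the case that still fails after those two replacements (a sketch reusing fun from above, with 5 simulated devices, so L = 2):

x = jnp.ones((10, 5))
z = jnp.array([1.])
pfun = jax.soft_pmap(fun, in_axes=(0, None))
pfun(x, z)  # still errors, seemingly inside pxla.split_axis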

Joshuaalbert (Contributor, Author) commented:

Note: a temporary fix is of course to just use partial or similar to pass in the unmapped arguments.
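
For example, a minimal sketch that closes over z so soft_pmap only sees the mapped argument (reusing fun, x, and z from the first example):

from functools import partial

pfun = jax.soft_pmap(partial(fun, z=z))  # z is baked in; only x is mapped
pfun(x)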

mattjj (Collaborator) commented Jun 11, 2020

How did you find out about soft_pmap!?!

Just kidding :) It's a bit of an ugly prototype at the moment. #3370 rewrites it, though only as a temporary patch before we make something better.

Since #3370 changes it pretty significantly (for example, _reshape_split is totally gone), I'm tempted to follow up on this issue only after we merge that PR, especially given that you have a workaround. Or alternatively, if you want to work from that branch, we could iterate on the branch together. WDYT?

mattjj self-assigned this Jun 11, 2020
mattjj added the enhancement (New feature or request) label Jun 11, 2020
Joshuaalbert (Contributor, Author) commented Jun 11, 2020

@mattjj Oh, I dig deep when I learn a new code base :) I like the principle of being able to simulate XLA devices for two reasons: the ability to map more elements than there are devices, and to test distributed programs on a single machine before deploying. It seems like there are two ways to get this: a method like soft_pmap, or a many-to-one mapping of devices. Perhaps I like the second idea more, but @skye mentioned that would be significantly harder. Excited to try the new code when it's out.

My use case is radio astronomy calibration and imaging. Typical datasets are ~10 TB, and we need to keep the code as close to the data as possible, so we try to avoid multi-machine computing. The machines are typically 512 GB with 64 cores, and the idea is to keep memory as free as possible and use the CPUs as much as possible while avoiding IO blocking. I'm going to try using JAX end-to-end on a new type of algorithm; still assessing whether that's possible.

Actually, it would be great to Skype and get some answers directly! It would probably save me weeks of probing. Would you be available?

Also, regarding your offer, I'll wait for the upcoming release and use the workaround for now.

hawkinsp (Collaborator) commented:

soft_pmap was deleted in jax 0.3.18. This issue is therefore stale...
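
For reference, plain jax.pmap handles unmapped arguments directly, since its in_axes accepts None for broadcast arguments (a minimal sketch against current jax):

import os
os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=5"
import jax
import jax.numpy as jnp

def fun(x, z):
    return (x - z, x + z)

x = jnp.ones((5, 5))
z = jnp.array([1.])
jax.pmap(fun, in_axes=(0, None))(x, z)  # z is broadcast to every device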
