Running multiple functions in parallel in JAX #25630
tomas-teijeiro asked this question in Q&A
Replies: 1 comment
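I am not sure if this is optimal for your case, but you can try something like this:

```python
import os

# Assumption (not in the original snippet): on a CPU-only machine, expose two
# host devices so the 2-device mesh below can be built. Must run before importing jax.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax
import jax.numpy as jnp
import functools as ft
from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map
import numpy.testing as npt

# 1-D mesh over 2 devices, with the axis named "P".
mesh = jax.make_mesh([2], ["P"])

def f1(x, y):
    return x @ y

def f2(x, y):
    return -x @ y

@ft.partial(
    shard_map,
    mesh=mesh,
    in_specs=(P(), P()),  # x and y are replicated on every device
    out_specs=P("P"),     # per-device outputs are concatenated along axis 0
)
def sharded_f(x, y):
    # Device 0 keeps the f1 result, device 1 keeps the f2 result. Note that
    # jnp.where evaluates both f1(x, y) and f2(x, y) before selecting.
    index = jax.lax.axis_index("P")
    return jnp.where(index == 0, f1(x, y), f2(x, y))

def unsharded_f(x, y):
    # Sequential reference: run both functions and stack the results.
    return jnp.vstack([f1(x, y), f2(x, y)])

D = 2048
x = jnp.ones((D, D))
y = jnp.ones((D, D))
npt.assert_allclose(sharded_f(x, y), unsharded_f(x, y), atol=1e-4)
```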
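One caveat on the `shard_map` approach: `jnp.where` evaluates both `f1(x, y)` and `f2(x, y)` on every device before selecting a result, so each device still does the work of both functions. A minimal sketch of a different technique, not taken from the answer, is to commit each function's inputs to its own device with `jax.device_put` and rely on JAX's asynchronous dispatch: a jitted call typically returns before its computation finishes, so calls placed on different devices can overlap without Python threads or processes. The helper `run_on_separate_devices` and the size `D = 512` are hypothetical; with a single device, every call goes to the same device and no cross-device overlap is expected.

```python
import itertools

import jax
import jax.numpy as jnp

@jax.jit
def f1(x, y):
    return x @ y

@jax.jit
def f2(x, y):
    return -x @ y

def run_on_separate_devices(fns, x, y):
    """Hypothetical helper: dispatch fns[i] on device i, cycling if there are fewer devices."""
    pending = []
    for fn, dev in zip(fns, itertools.cycle(jax.devices())):
        # Committing the inputs to `dev` makes the jitted call execute on `dev`.
        xd = jax.device_put(x, dev)
        yd = jax.device_put(y, dev)
        # Asynchronous dispatch: this usually returns before the result is computed.
        pending.append(fn(xd, yd))
    # Block only once, after every function has been dispatched.
    return [out.block_until_ready() for out in pending]

D = 512
x = jnp.ones((D, D))
y = jnp.ones((D, D))
results = run_on_separate_devices([f1, f2], x, y)
```

On a CPU-only machine, setting `XLA_FLAGS=--xla_force_host_platform_device_count=2` before importing `jax` exposes two host devices for experimenting; whether this gives a real speedup depends on the hardware, so it is worth timing both versions.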
I have a bunch of JAX-jitted functions that I want to run in parallel on the same input (a JAX array). However, I couldn't find any proper way of doing this, either within the JAX framework or with the standard Python multiprocessing APIs. This is a minimal working example of what I want:

Everything works well in this example, but the execution time of `run_multiple_functions` is the sum of the execution times of all the functions. If I use a `ProcessPoolExecutor` inside `run_multiple_functions`, the code breaks due to a pickling problem, and if I use a `ThreadPoolExecutor` I don't get any benefit because of the GIL. Any ideas on how I could do this? Many thanks in advance!!
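For the single-device case, another option (a sketch under stated assumptions, not something from the question or the answer) is to trace all of the functions into one jitted program that returns a tuple. This gives one compilation and one dispatch for the whole set instead of one per function, and XLA sees the independent computations together; whether it actually overlaps them is backend-dependent, so measure before relying on it. The functions `f_sum`, `f_max`, `f_norm` and the helper `make_run_all` are hypothetical stand-ins for the question's functions.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the question's functions; each takes the same array.
def f_sum(x):
    return jnp.sum(x)

def f_max(x):
    return jnp.max(x)

def f_norm(x):
    return jnp.linalg.norm(x)

def make_run_all(fns):
    """Hypothetical helper: one jitted callable that evaluates every function in `fns`."""
    @jax.jit
    def run_all(x):
        # Every call is traced into a single XLA program, so all functions are
        # compiled together and launched with a single dispatch.
        return tuple(f(x) for f in fns)
    return run_all

run_all = make_run_all([f_sum, f_max, f_norm])
x = jnp.arange(4.0)
outs = run_all(x)
```

Functions that are already wrapped in `jax.jit` can be passed in as-is; a jitted function called inside another `jit` is compiled as part of the outer program.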