diff --git a/CHANGELOG.md b/CHANGELOG.md index fdaa253d1649..9a00e5054833 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Changes * `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as env vars. Before they could only be specified via jax.config or flags. + * `JAX_CPU_COLLECTIVES_IMPLEMENTATION` now defaults to `'gloo'`, meaning + multi-process CPU communication works out-of-the-box. ## jax 0.5.0 (Jan 17, 2025) diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 51f00c56dd6a..9ef77913ac06 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -62,6 +62,8 @@ MIN_COMPUTE_CAPABILITY = 52 +_DEFAULT_CPU_COLLECTIVES_IMPL = 'gloo' + # TODO(phawkins): Remove jax_xla_backend. _XLA_BACKEND = config.string_flag( 'jax_xla_backend', '', @@ -235,7 +237,9 @@ def make_cpu_client( Returns: The created CPU client. """ - if collectives is None: + # TODO(skyewm): use distributed.is_initialized() after + # https://github.com/jax-ml/jax/pull/26172 goes in. + if collectives is None and distributed.global_state.client is not None: collectives_impl = config.cpu_collectives_implementation.value if _CPU_ENABLE_GLOO_COLLECTIVES.value: collectives_impl = 'gloo' @@ -244,6 +248,9 @@ def make_cpu_client( '"jax_cpu_collectives_implementation", "gloo")` instead.', DeprecationWarning, ) + if collectives_impl is None: + collectives_impl = _DEFAULT_CPU_COLLECTIVES_IMPL + if collectives_impl == 'gloo': collectives = xla_client._xla.make_gloo_tcp_collectives( distributed_client=distributed.global_state.client, @@ -252,8 +259,6 @@ def make_cpu_client( collectives = xla_client._xla.make_mpi_collectives() collectives.Init() atexit.register(collectives.Finalize) - elif collectives_impl == 'megascale': - raise ValueError('JAX_CPU_COLLECTIVES_IMPLEMENTATION must "gloo" or "mpi"') else: # Already validated by config module assert collectives_impl is None