Skip to content

Commit

Permalink
Default JAX_CPU_COLLECTIVES_IMPLEMENTATION to 'gloo'.
Browse files Browse the repository at this point in the history
This enables CPU collectives by default, making multi-process CPU
communication work without extra configuration.

PiperOrigin-RevId: 722209646
  • Loading branch information
skye authored and Google-ML-Automation committed Feb 4, 2025
1 parent bc1a706 commit 4b3585e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Changes
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as
env vars. Previously, they could only be specified via `jax.config` or flags.
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` now defaults to `'gloo'`, meaning
multi-process CPU communication works out of the box.

## jax 0.5.0 (Jan 17, 2025)

Expand Down
11 changes: 8 additions & 3 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@

MIN_COMPUTE_CAPABILITY = 52

_DEFAULT_CPU_COLLECTIVES_IMPL = 'gloo'

# TODO(phawkins): Remove jax_xla_backend.
_XLA_BACKEND = config.string_flag(
'jax_xla_backend', '',
Expand Down Expand Up @@ -235,7 +237,9 @@ def make_cpu_client(
Returns:
The created CPU client.
"""
if collectives is None:
# TODO(skyewm): use distributed.is_initialized() after
# https://github.com/jax-ml/jax/pull/26172 goes in.
if collectives is None and distributed.global_state.client is not None:
collectives_impl = config.cpu_collectives_implementation.value
if _CPU_ENABLE_GLOO_COLLECTIVES.value:
collectives_impl = 'gloo'
Expand All @@ -244,6 +248,9 @@ def make_cpu_client(
'"jax_cpu_collectives_implementation", "gloo")` instead.',
DeprecationWarning,
)
if collectives_impl is None:
collectives_impl = _DEFAULT_CPU_COLLECTIVES_IMPL

if collectives_impl == 'gloo':
collectives = xla_client._xla.make_gloo_tcp_collectives(
distributed_client=distributed.global_state.client,
Expand All @@ -252,8 +259,6 @@ def make_cpu_client(
collectives = xla_client._xla.make_mpi_collectives()
collectives.Init()
atexit.register(collectives.Finalize)
elif collectives_impl == 'megascale':
raise ValueError('JAX_CPU_COLLECTIVES_IMPLEMENTATION must "gloo" or "mpi"')
else:
# Already validated by config module
assert collectives_impl is None
Expand Down

0 comments on commit 4b3585e

Please sign in to comment.