
Commit

Add default flag xla_tpu_prefer_async_allgather_to_allreduce (#6528)
jonb377 authored and bhavya01 committed Apr 22, 2024
1 parent 77569df commit 67abdba
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion torch_xla/__init__.py
@@ -45,9 +45,20 @@ def _setup_libtpu_flags():
  flags = _set_missing_flags(flags,
                             (('xla_latency_hiding_scheduler_rerun', '1'),))

  # This flag prevents the compiler from decomposing AllGathers into
  # AllReduces when async AllGather is enabled. Decomposed AllGathers are
  # persisted in memory and shared between the forward and backward passes,
  # which can result in the entire model's parameters residing in device
  # memory. Regular AllGathers are instead rematerialized in the backward
  # pass; when they are async, this incurs little overhead but significantly
  # improves device memory usage.
  flags = _set_missing_flags(
      flags, (('xla_tpu_prefer_async_allgather_to_allreduce', 'true'),))

  if tpu.version() == 5:
    default_v5_flags = {
        # Enable async collectives
        # TODO(jonbolin): Tune these flags for async collective fusion - v5
        # requires continuation fusion to run async collectives.
        'xla_enable_async_all_gather': 'true',
        'xla_enable_async_collective_permute': 'true',
    }
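
For readers who want to opt out of (or pin) this new default, the sketch below shows one way to pre-seed the flag before torch_xla is imported. It assumes the flags assembled by `_setup_libtpu_flags` end up in the `LIBTPU_INIT_ARGS` environment variable and that `_set_missing_flags` only adds a default when the flag is not already present, so a user-provided value takes precedence; both points are inferences from the hunk above, not something this diff confirms.

```python
import os

# Minimal sketch (assumption): pre-seed LIBTPU_INIT_ARGS before importing
# torch_xla so the default added by this commit sees the flag as already set
# and leaves it alone. The flag name comes from this commit; the
# "--flag=value" spelling and the precedence of pre-existing values are
# assumptions, not guaranteed by this diff.
existing = os.environ.get('LIBTPU_INIT_ARGS', '')
override = '--xla_tpu_prefer_async_allgather_to_allreduce=false'
os.environ['LIBTPU_INIT_ARGS'] = f'{existing} {override}'.strip()

import torch_xla  # libtpu flag defaults are filled in during import on TPU hosts
```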