Set --xla_latency_hiding_scheduler_rerun to 1 (#5736)
Summary:
This flag reruns the latency hiding scheduler if the default shared
memory limit of 95% leads to OOM. Each rerun uses 0.9x the limit of the
previous run, and the number of reruns is currently set to 1.
The shared memory limit refers to --xla_tpu_scheduler_percent_shared_memory_limit.
A lower shared memory limit means less overlap between communication and
computation, and thus worse performance.
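As a rough sketch of the arithmetic described in the summary (not part of this commit), the sequence of limits the scheduler would try looks like the following; the helper name is hypothetical, and the 95% default and 0.9x factor come from the summary text:

def scheduler_memory_limits(reruns=1, initial_limit=95.0, factor=0.9):
  # Hypothetical helper: lists the shared memory limit (in percent) used
  # on the first scheduling attempt and after each rerun, following the
  # "0.9x of the previous run" rule from the commit summary.
  limits = [initial_limit]
  for _ in range(reruns):
    limits.append(limits[-1] * factor)
  return limits

# With xla_latency_hiding_scheduler_rerun=1 this yields [95.0, 85.5],
# i.e. one retry at roughly 85.5% shared memory.
print(scheduler_memory_limits(reruns=1))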

Test Plan:
Tested with Llama 2 7B on a v4-32.
alanwaketan authored and golechwierowicz committed Jan 12, 2024
1 parent 88a5495 commit 2d35638
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions torch_xla/__init__.py
@@ -32,11 +32,26 @@ def _setup_xla_flags():
  os.environ['XLA_FLAGS'] = ' '.join(flags)


def _setup_libtpu_flags():
  flags = os.environ.get('LIBTPU_INIT_ARGS', '').split(' ')
  # This flag will rerun the latency hiding scheduler if the default
  # shared memory limit of 95% leads to OOM. Each rerun will choose a value
  # 0.9x of the previous run, and the number of reruns is set to 1 for now.
  # The shared memory limit refers to --xla_tpu_scheduler_percent_shared_memory_limit.
  # A lower shared memory limit means less overlap between communication and
  # computation, and thus worse performance.
  flags = _set_missing_flags(flags,
                             (('xla_latency_hiding_scheduler_rerun', '1'),))
  os.environ['LIBTPU_INIT_ARGS'] = ' '.join(flags)


def _setup_default_env():
  os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
  os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')

  if tpu.num_available_chips() > 0:
    _setup_libtpu_flags()

    os.environ.setdefault('ALLOW_MULTIPLE_LIBTPU_LOAD', '1')
    os.environ.setdefault('TPU_ML_PLATFORM', 'PyTorch/XLA')

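Since the flag is plumbed through LIBTPU_INIT_ARGS, a user who wants a different rerun count could set the variable before importing torch_xla. This is a minimal sketch, not from this commit, assuming _set_missing_flags only fills in flags that are not already present (as its name suggests) and that the flag is passed in --flag=value form:

import os

# Hypothetical override: request two reruns instead of the default one.
# Must be set before importing torch_xla so _setup_libtpu_flags() sees it
# and (assuming _set_missing_flags skips flags already present) keeps the
# user-provided value.
os.environ['LIBTPU_INIT_ARGS'] = '--xla_latency_hiding_scheduler_rerun=2'

import torch_xla  # environment setup runs at import time (assumption)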
