From 2d35638d9bbe03d5a76ea2944062d44b3f873d43 Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Thu, 26 Oct 2023 09:22:59 -0700
Subject: [PATCH] Set --xla_latency_hiding_scheduler_rerun to 1 (#5736)

Summary:
This flag will rerun the latency hiding scheduler if the default shared
memory limit of 95% leads to OOM. Each rerun will choose a value 0.9x of
the previous run, and the number of reruns is set to 1 for now.

The shared memory limit refers to
--xla_tpu_scheduler_percent_shared_memory_limit. A lower shared memory
limit means less communication and computation overlap, and thus worse
performance.

Test Plan:
Tested on Llama 2 7B on v4-32.
---
 torch_xla/__init__.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index eeaa2aaba0c9..2d5226870486 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -32,11 +32,26 @@ def _setup_xla_flags():
   os.environ['XLA_FLAGS'] = ' '.join(flags)
 
 
+def _setup_libtpu_flags():
+  flags = os.environ.get('LIBTPU_INIT_ARGS', '').split(' ')
+  # This flag will rerun the latency hiding scheduler if the default
+  # shared memory limit of 95% leads to OOM. Each rerun will choose a
+  # value 0.9x of the previous run, and the number of reruns is set to
+  # 1 for now. The shared memory limit refers to
+  # --xla_tpu_scheduler_percent_shared_memory_limit. A lower limit means
+  # less communication and computation overlap, and thus worse performance.
+  flags = _set_missing_flags(flags,
+                             (('xla_latency_hiding_scheduler_rerun', '1'),))
+  os.environ['LIBTPU_INIT_ARGS'] = ' '.join(flags)
+
+
 def _setup_default_env():
   os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
   os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
 
   if tpu.num_available_chips() > 0:
+    _setup_libtpu_flags()
+
     os.environ.setdefault('ALLOW_MULTIPLE_LIBTPU_LOAD', '1')
     os.environ.setdefault('TPU_ML_PLATFORM', 'PyTorch/XLA')
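
Note: _set_missing_flags is an existing helper in torch_xla/__init__.py that
is not shown in this diff. The sketch below is an assumption about its
behavior, not the actual implementation: it appends a flag only when no flag
with that name is already present, so a value supplied by the user through
LIBTPU_INIT_ARGS takes precedence over the default added by this patch.

# Hypothetical sketch of the _set_missing_flags helper (not part of this
# patch; the real implementation may differ in detail).
def _set_missing_flags(flags, sets):
  for name, defval in sets:
    # Only add "--name=value" if the user has not already set this flag,
    # e.g. via LIBTPU_INIT_ARGS="--xla_latency_hiding_scheduler_rerun=2".
    if not any(f.startswith('--{}='.format(name)) for f in flags):
      flags.append('--{}={}'.format(name, defval))
  return flags

Under this assumption, the default rerun count of 1 applies only when the
flag is absent from LIBTPU_INIT_ARGS at import time.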