From 90d8774154c84be8395c420c58eff553943651bc Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Wed, 25 Oct 2023 22:17:35 +0000
Subject: [PATCH 1/4] Set --xla_latency_hiding_scheduler_rerun to 1

Summary:
This flag makes XLA rerun the latency hiding scheduler if the default
shared memory limit of 95% leads to OOM. Each rerun uses 0.9x the limit
of the previous run, and the number of reruns is currently set to 1.
The shared memory limit refers to
--xla_tpu_scheduler_percent_shared_memory_limit. A lower shared memory
limit means less overlap between communication and computation, and
thus worse performance.

Test Plan:
Tested with Llama 2 7B on a v4-32.
---
 torch_xla/__init__.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index eeaa2aaba0c..0c5a85241f8 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -29,6 +29,12 @@ def _setup_xla_flags():
       flags, (('xla_gpu_simplify_all_fp_conversions', 'false'),))
   flags = _set_missing_flags(flags,
                              (('xla_gpu_force_compilation_parallelism', '8'),))
+  # This flag will rerun the latency hiding scheduler if the default
+  # shared memory limit of 95% leads to OOM. Each rerun will choose a value
+  # 0.9x of the previous run, and the number of reruns is set to 1 for now.
+  # Shared memory limit refers to --xla_tpu_scheduler_percent_shared_memory_limit.
+  flags = _set_missing_flags(flags,
+                             (('xla_latency_hiding_scheduler_rerun', '1'),))
   os.environ['XLA_FLAGS'] = ' '.join(flags)
 
 

From a0e0d3d80450e2a6be12781aaa97bd1e0f3dc0c4 Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Wed, 25 Oct 2023 22:20:37 +0000
Subject: [PATCH 2/4] Add more comments

---
 torch_xla/__init__.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index 0c5a85241f8..a51fd0dfc3f 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -33,6 +33,8 @@ def _setup_xla_flags():
   # shared memory limit of 95% leads to OOM. Each rerun will choose a value
   # 0.9x of the previous run, and the number of reruns is set to 1 for now.
   # Shared memory limit refers to --xla_tpu_scheduler_percent_shared_memory_limit.
+  # Lower shared memory limit means less communication and computation overlapping,
+  # and thus worse performance.
   flags = _set_missing_flags(flags,
                              (('xla_latency_hiding_scheduler_rerun', '1'),))
   os.environ['XLA_FLAGS'] = ' '.join(flags)

From c7f477b7bca6523787dc15d8a405383c4b954174 Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Thu, 26 Oct 2023 00:37:51 +0000
Subject: [PATCH 3/4] Introduce LIBTPU_INIT_ARGS

---
 torch_xla/__init__.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index a51fd0dfc3f..1b48b07e7f8 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -29,21 +29,26 @@ def _setup_xla_flags():
       flags, (('xla_gpu_simplify_all_fp_conversions', 'false'),))
   flags = _set_missing_flags(flags,
                              (('xla_gpu_force_compilation_parallelism', '8'),))
+  os.environ['XLA_FLAGS'] = ' '.join(flags)
+
+def _setup_libtpu_flags():
+  flags = os.environ.get('LIBTPU_INIT_ARGS', '').split(' ')
   # This flag will rerun the latency hiding scheduler if the default
   # shared memory limit of 95% leads to OOM. Each rerun will choose a value
   # 0.9x of the previous run, and the number of reruns is set to 1 for now.
   # Shared memory limit refers to --xla_tpu_scheduler_percent_shared_memory_limit.
   # Lower shared memory limit means less communication and computation overlapping,
   # and thus worse performance.
   flags = _set_missing_flags(flags,
-                             (('xla_latency_hiding_scheduler_rerun', '1'),))
-  os.environ['XLA_FLAGS'] = ' '.join(flags)
-
+                            (('xla_latency_hiding_scheduler_rerun', '1'),))
+  os.environ['LIBTPU_INIT_ARGS'] = ' '.join(flags)
 
 def _setup_default_env():
   os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
   os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
 
   if tpu.num_available_chips() > 0:
+    _setup_libtpu_flags()
+
     os.environ.setdefault('ALLOW_MULTIPLE_LIBTPU_LOAD', '1')
     os.environ.setdefault('TPU_ML_PLATFORM', 'PyTorch/XLA')

From 4afa5d50698e83e9b5975e1beb635d14a7c7be61 Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Thu, 26 Oct 2023 01:06:23 +0000
Subject: [PATCH 4/4] Fix lint errors

---
 torch_xla/__init__.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index 1b48b07e7f8..2d522687048 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -31,6 +31,7 @@ def _setup_xla_flags():
                              (('xla_gpu_force_compilation_parallelism', '8'),))
   os.environ['XLA_FLAGS'] = ' '.join(flags)
 
+
 def _setup_libtpu_flags():
   flags = os.environ.get('LIBTPU_INIT_ARGS', '').split(' ')
   # This flag will rerun the latency hiding scheduler if the default
@@ -40,9 +41,10 @@ def _setup_libtpu_flags():
   # Lower shared memory limit means less communication and computation overlapping,
   # and thus worse performance.
   flags = _set_missing_flags(flags,
-                            (('xla_latency_hiding_scheduler_rerun', '1'),))
+                             (('xla_latency_hiding_scheduler_rerun', '1'),))
   os.environ['LIBTPU_INIT_ARGS'] = ' '.join(flags)
 
+
 def _setup_default_env():
   os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
   os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
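
Note on overriding the default: _set_missing_flags(), as its name and its use for the
existing XLA_FLAGS defaults suggest, only fills in flags the user has not already
provided, so the rerun count chosen by this series acts as a default rather than a hard
override. The snippet below is a minimal sketch of that assumption; the value 2 and the
--name=value flag syntax are illustrative and not part of the patch.

import os

# Example only: request up to two scheduler reruns instead of the single
# rerun configured by the patch. Must run before torch_xla is imported,
# since _setup_libtpu_flags() reads LIBTPU_INIT_ARGS at import time.
os.environ['LIBTPU_INIT_ARGS'] = '--xla_latency_hiding_scheduler_rerun=2'

import torch_xla  # assuming _set_missing_flags() keeps the user-provided value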