diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index eeaa2aaba0c9..2d5226870486 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -32,11 +32,26 @@ def _setup_xla_flags():
   os.environ['XLA_FLAGS'] = ' '.join(flags)
 
 
+def _setup_libtpu_flags():
+  flags = os.environ.get('LIBTPU_INIT_ARGS', '').split(' ')
+  # This flag reruns the latency hiding scheduler if the default shared
+  # memory limit of 95% leads to an OOM. Each rerun uses a limit that is
+  # 0.9x that of the previous run; the number of reruns is currently set
+  # to 1. The shared memory limit here refers to
+  # --xla_tpu_scheduler_percent_shared_memory_limit. A lower limit means
+  # less overlap of communication and computation, and thus worse performance.
+  flags = _set_missing_flags(flags,
+                             (('xla_latency_hiding_scheduler_rerun', '1'),))
+  os.environ['LIBTPU_INIT_ARGS'] = ' '.join(flags)
+
+
 def _setup_default_env():
   os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
   os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
 
   if tpu.num_available_chips() > 0:
+    _setup_libtpu_flags()
+
     os.environ.setdefault('ALLOW_MULTIPLE_LIBTPU_LOAD', '1')
     os.environ.setdefault('TPU_ML_PLATFORM', 'PyTorch/XLA')
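
Note: the new helper relies on _set_missing_flags, which already exists in this file but is not shown in the hunk. Below is a minimal, hypothetical sketch of the merge behavior the change depends on: a flag the user already put in LIBTPU_INIT_ARGS is left untouched, and the default is only appended when missing. The body of _set_missing_flags here is an assumption for illustration, not the actual torch_xla implementation.

# Hypothetical sketch (not the actual torch_xla code) of the flag-merging
# behavior _setup_libtpu_flags relies on: defaults are appended only when
# the flag is not already present, so user-provided values win.
import os


def _set_missing_flags(flags, sets):
  # Append --name=value only when --name is absent from the current flags.
  for name, defval in sets:
    if not any(f.startswith('--{}'.format(name)) for f in flags):
      flags.append('--{}={}'.format(name, defval))
  return flags


if __name__ == '__main__':
  # A user explicitly disabling the rerun keeps their setting; the
  # default of '1' is not appended on top of it.
  os.environ['LIBTPU_INIT_ARGS'] = '--xla_latency_hiding_scheduler_rerun=0'
  flags = os.environ.get('LIBTPU_INIT_ARGS', '').split(' ')
  flags = _set_missing_flags(
      flags, (('xla_latency_hiding_scheduler_rerun', '1'),))
  print(' '.join(flags))  # prints: --xla_latency_hiding_scheduler_rerun=0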