From d8f67f03e4e444a10e5b34c2bcb2b3c726d29377 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 16 Oct 2023 10:35:14 -0700 Subject: [PATCH] Don't set $TPU_LIBRARY_PATH during import (#5698) * Don't set $TPU_LIBRARY_PATH during import * remove chekck for new env var so people don't use it --- torch_xla/__init__.py | 9 ++++++--- torch_xla/csrc/runtime/env_vars.cc | 1 + torch_xla/csrc/runtime/env_vars.h | 1 + torch_xla/csrc/runtime/pjrt_computation_client.cc | 9 +++++---- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index eeaa2aaba0c..ce1787d0fe3 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -72,13 +72,16 @@ def _aws_ec2_inf_trn_init(): def _setup_tpu_vm_library_path() -> bool: - """Returns true if $TPU_LIBRARY is set or can be inferred. + """Returns true if $TPU_LIBRARY_PATH is set or can be inferred. We load libtpu.so in the following order of precedence: 1. User-set $TPU_LIBRARY_PATH 2. libtpu.so included in torch_xla/lib 3. libtpu-nightly pip package + + Sets $PTXLA_TPU_LIBRARY_PATH if path is inferred by us to prevent conflicts + with other frameworks. This env var will be removed in a future version. """ if 'TPU_LIBRARY_PATH' in os.environ: return True @@ -87,12 +90,12 @@ def _setup_tpu_vm_library_path() -> bool: bundled_libtpu_path = os.path.join(module_path, 'lib/libtpu.so') if os.path.isfile(bundled_libtpu_path) and not os.getenv('TPU_LIBRARY_PATH'): logger.info('Using bundled libtpu.so (%s)', bundled_libtpu_path) - os.environ['TPU_LIBRARY_PATH'] = bundled_libtpu_path + os.environ['PTXLA_TPU_LIBRARY_PATH'] = bundled_libtpu_path return True try: import libtpu - libtpu.configure_library_path() + os.environ['PTXLA_TPU_LIBRARY_PATH'] = libtpu.get_library_path() return True except ImportError: return False diff --git a/torch_xla/csrc/runtime/env_vars.cc b/torch_xla/csrc/runtime/env_vars.cc index 42040a9cca5..00ffb1f2a25 100644 --- a/torch_xla/csrc/runtime/env_vars.cc +++ b/torch_xla/csrc/runtime/env_vars.cc @@ -14,6 +14,7 @@ const char* const kEnvPjRtTpuMaxInflightComputations = const char* const kEnvPjrtAsyncCpuClient = "PJRT_CPU_ASYNC_CLIENT"; const char* const kEnvPjrtAsyncGpuClient = "PJRT_GPU_ASYNC_CLIENT"; const char* const kEnvTpuLibraryPath = "TPU_LIBRARY_PATH"; +const char* const kEnvInferredTpuLibraryPath = "PTXLA_TPU_LIBRARY_PATH"; const char* const kEnvXpuLibraryPath = "XPU_LIBRARY_PATH"; const char* const kEnvNeuronLibraryPath = "NEURON_LIBRARY_PATH"; const char* const kEnvPjrtDistServiceAddr = "PJRT_DIST_SERVICE_ADDR"; diff --git a/torch_xla/csrc/runtime/env_vars.h b/torch_xla/csrc/runtime/env_vars.h index e54ba8f72cd..72849003765 100644 --- a/torch_xla/csrc/runtime/env_vars.h +++ b/torch_xla/csrc/runtime/env_vars.h @@ -24,6 +24,7 @@ extern const char* const kEnvPjRtTpuMaxInflightComputations; extern const char* const kEnvPjrtAsyncCpuClient; extern const char* const kEnvPjrtAsyncGpuClient; extern const char* const kEnvTpuLibraryPath; +extern const char* const kEnvInferredTpuLibraryPath; extern const char* const kEnvXpuLibraryPath; extern const char* const kEnvNeuronLibraryPath; extern const char* const kEnvPjrtDistServiceAddr; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 237531d43cf..a4a8af1ac9e 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -109,10 +109,11 @@ PjRtComputationClient::PjRtComputationClient() { client_ = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value()); } else if (device_type == "TPU" || device_type == "TPU_C_API") { TF_VLOG(1) << "Initializing TFRT TPU client..."; - XLA_CHECK_OK( - pjrt::LoadPjrtPlugin( - "tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so")) - .status()); + // Prefer $TPU_LIBRARY_PATH if set + auto tpu_library_path = sys_util::GetEnvString( + env::kEnvTpuLibraryPath, + sys_util::GetEnvString(env::kEnvInferredTpuLibraryPath, "libtpu.so")); + XLA_CHECK_OK(pjrt::LoadPjrtPlugin("tpu", tpu_library_path).status()); tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu"); XLA_CHECK(tpu_status.ok()); client_ = std::move(xla::GetCApiClient("TPU").value());