Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't set $TPU_LIBRARY_PATH during import #5698

Merged
merged 2 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions torch_xla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,16 @@ def _aws_ec2_inf_trn_init():


def _setup_tpu_vm_library_path() -> bool:
"""Returns true if $TPU_LIBRARY is set or can be inferred.
"""Returns true if $TPU_LIBRARY_PATH is set or can be inferred.

We load libtpu.so in the following order of precedence:

1. User-set $TPU_LIBRARY_PATH
2. libtpu.so included in torch_xla/lib
3. libtpu-nightly pip package

Sets $PTXLA_TPU_LIBRARY_PATH if path is inferred by us to prevent conflicts
with other frameworks. This env var will be removed in a future version.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder under what condition can we remove this env var.

"""
if 'TPU_LIBRARY_PATH' in os.environ:
return True
Expand All @@ -87,12 +90,12 @@ def _setup_tpu_vm_library_path() -> bool:
bundled_libtpu_path = os.path.join(module_path, 'lib/libtpu.so')
if os.path.isfile(bundled_libtpu_path) and not os.getenv('TPU_LIBRARY_PATH'):
logger.info('Using bundled libtpu.so (%s)', bundled_libtpu_path)
os.environ['TPU_LIBRARY_PATH'] = bundled_libtpu_path
os.environ['PTXLA_TPU_LIBRARY_PATH'] = bundled_libtpu_path
return True

try:
import libtpu
libtpu.configure_library_path()
os.environ['PTXLA_TPU_LIBRARY_PATH'] = libtpu.get_library_path()
return True
except ImportError:
return False
Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/env_vars.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ const char* const kEnvPjRtTpuMaxInflightComputations =
const char* const kEnvPjrtAsyncCpuClient = "PJRT_CPU_ASYNC_CLIENT";
const char* const kEnvPjrtAsyncGpuClient = "PJRT_GPU_ASYNC_CLIENT";
const char* const kEnvTpuLibraryPath = "TPU_LIBRARY_PATH";
const char* const kEnvInferredTpuLibraryPath = "PTXLA_TPU_LIBRARY_PATH";
const char* const kEnvXpuLibraryPath = "XPU_LIBRARY_PATH";
const char* const kEnvNeuronLibraryPath = "NEURON_LIBRARY_PATH";
const char* const kEnvPjrtDistServiceAddr = "PJRT_DIST_SERVICE_ADDR";
Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/env_vars.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ extern const char* const kEnvPjRtTpuMaxInflightComputations;
extern const char* const kEnvPjrtAsyncCpuClient;
extern const char* const kEnvPjrtAsyncGpuClient;
extern const char* const kEnvTpuLibraryPath;
extern const char* const kEnvInferredTpuLibraryPath;
extern const char* const kEnvXpuLibraryPath;
extern const char* const kEnvNeuronLibraryPath;
extern const char* const kEnvPjrtDistServiceAddr;
Expand Down
9 changes: 5 additions & 4 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,11 @@ PjRtComputationClient::PjRtComputationClient() {
client_ = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value());
} else if (device_type == "TPU" || device_type == "TPU_C_API") {
TF_VLOG(1) << "Initializing TFRT TPU client...";
XLA_CHECK_OK(
pjrt::LoadPjrtPlugin(
"tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so"))
.status());
// Prefer $TPU_LIBRARY_PATH if set
auto tpu_library_path = sys_util::GetEnvString(
env::kEnvTpuLibraryPath,
sys_util::GetEnvString(env::kEnvInferredTpuLibraryPath, "libtpu.so"));
XLA_CHECK_OK(pjrt::LoadPjrtPlugin("tpu", tpu_library_path).status());
tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu");
XLA_CHECK(tpu_status.ok());
client_ = std::move(xla::GetCApiClient("TPU").value());
Expand Down