Remove GetTensorsFused since we don't have op-by-op anymore (#5718)
JackCaoG authored and golechwierowicz committed Jan 12, 2024
1 parent 05354c4 commit de69894
Showing 2 changed files with 34 additions and 42 deletions.
73 changes: 34 additions & 39 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -410,7 +410,40 @@ std::vector<at::Tensor> XLAGraphExecutor::GetTensors(
     std::vector<XLATensorPtr>* tensors) {
   TF_VLOG(4) << "Trying to get the value of " << tensors->size()
              << " tensor(s)";
-  return GetTensorsFused(tensors);
+  SyncTensorsConfig config;
+  config.force_ltc_data = false;
+  auto async = SyncTensorsGraphInternal(tensors, {}, config);
+  if (async != nullptr) {
+    async->mwait.Wait();
+  }
+  std::vector<torch::lazy::BackendDataPtr> tensors_data = GatherTensorsXlaData(
+      *tensors, async != nullptr ? async->indices : absl::Span<const size_t>(),
+      async != nullptr ? async->tensors_data
+                       : absl::Span<const torch::lazy::BackendDataPtr>());
+
+  // Execution is async in PJRT, so TransferFromServer may block until execution
+  // completes. Release the GIL so other threads can proceed and unblock any
+  // collective computations.
+  // HACK: This method may be called outside of python (mainly in C++ tests) or
+  // when the GIL is already released, so we must check both cases here. If
+  // possible, prefer to release the GIL in the python bindings before copying
+  // this pattern.
+  PyThreadState* save = nullptr;
+  // TODO(wcromar): Remove this setting when we are more confident
+  static const bool release_gil =
+      runtime::sys_util::GetEnvBool("XLA_RELEASE_GIL_DURING_TRANSFER", true);
+  if (release_gil && Py_IsInitialized() && PyGILState_Check()) {
+    save = PyEval_SaveThread();
+  }
+  std::vector<xla::Literal> literals =
+      runtime::GetComputationClient()->TransferFromServer(
+          UnwrapXlaData(tensors_data));
+  if (save) {
+    PyEval_RestoreThread(save);
+  }
+
+  return FetchTensors(tensors, literals,
+                      async != nullptr ? &async->indices : nullptr);
 }

 torch::lazy::hash_t XLAGraphExecutor::GetGraphHash(
@@ -798,44 +831,6 @@ std::vector<torch::lazy::BackendDataPtr> XLAGraphExecutor::ExecuteStablehlo(
   return WrapXlaData(result_data);
 }

-std::vector<at::Tensor> XLAGraphExecutor::GetTensorsFused(
-    std::vector<XLATensorPtr>* tensors) {
-  SyncTensorsConfig config;
-  config.force_ltc_data = false;
-  auto async = SyncTensorsGraphInternal(tensors, {}, config);
-  if (async != nullptr) {
-    async->mwait.Wait();
-  }
-  std::vector<torch::lazy::BackendDataPtr> tensors_data = GatherTensorsXlaData(
-      *tensors, async != nullptr ? async->indices : absl::Span<const size_t>(),
-      async != nullptr ? async->tensors_data
-                       : absl::Span<const torch::lazy::BackendDataPtr>());
-
-  // Execution is async in PJRT, so TransferFromServer may block until execution
-  // completes. Release the GIL so other threads can proceed and unblock any
-  // collective computations.
-  // HACK: This method may be called outside of python (mainly in C++ tests) or
-  // when the GIL is already released, so we must check both cases here. If
-  // possible, prefer to release the GIL in the python bindings before copying
-  // this pattern.
-  PyThreadState* save = nullptr;
-  // TODO(wcromar): Remove this setting when we are more confident
-  static const bool release_gil =
-      runtime::sys_util::GetEnvBool("XLA_RELEASE_GIL_DURING_TRANSFER", true);
-  if (release_gil && Py_IsInitialized() && PyGILState_Check()) {
-    save = PyEval_SaveThread();
-  }
-  std::vector<xla::Literal> literals =
-      runtime::GetComputationClient()->TransferFromServer(
-          UnwrapXlaData(tensors_data));
-  if (save) {
-    PyEval_RestoreThread(save);
-  }
-
-  return FetchTensors(tensors, literals,
-                      async != nullptr ? &async->indices : nullptr);
-}
-
 std::vector<torch::lazy::BackendDataPtr> XLAGraphExecutor::GatherTensorsXlaData(
     const std::vector<XLATensorPtr>& tensors, absl::Span<const size_t> indices,
     absl::Span<const torch::lazy::BackendDataPtr> tensors_data) {
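Note (not part of the commit): the PyEval_SaveThread / PyEval_RestoreThread pair around TransferFromServer is the standard CPython C API idiom for releasing the GIL around a blocking call, and per the code above it can be turned off by setting XLA_RELEASE_GIL_DURING_TRANSFER to false. The same checks can be packaged as an RAII guard so the GIL is restored even on early returns or exceptions. A minimal, self-contained sketch; the GilReleaseGuard class is illustrative and does not exist in torch_xla:

#include <Python.h>

// Illustrative RAII guard (hypothetical; not in torch_xla). Releases the
// GIL in the constructor and restores it in the destructor, using the same
// checks as the code above: only release when the interpreter is
// initialized and this thread currently holds the GIL.
class GilReleaseGuard {
 public:
  GilReleaseGuard() {
    if (Py_IsInitialized() && PyGILState_Check()) {
      save_ = PyEval_SaveThread();
    }
  }
  ~GilReleaseGuard() {
    if (save_ != nullptr) {
      PyEval_RestoreThread(save_);
    }
  }
  // The saved thread state must be restored exactly once, so forbid copies.
  GilReleaseGuard(const GilReleaseGuard&) = delete;
  GilReleaseGuard& operator=(const GilReleaseGuard&) = delete;

 private:
  PyThreadState* save_ = nullptr;
};

With such a guard, a scope around the TransferFromServer call would release and restore the GIL without the manual save/restore bookkeeping seen in the diff.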
3 changes: 0 additions & 3 deletions torch_xla/csrc/xla_graph_executor.h
@@ -258,9 +258,6 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
   // Override to enable SPMD.
   void TensorCollectionBarrier(SyncTensorCollection* coll) final;

-  // We don't use upstream GetTensorsFused as we have xla::Literal.
-  std::vector<at::Tensor> GetTensorsFused(std::vector<XLATensorPtr>* tensors);
-
   // Gathers the XLA device data for all the input tensors, after an
   // asynchronous operation.
   // TODO(alanwaketan): Reuse the upstream one once Functionalization is done.
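With the declaration removed from the header as well, GetTensors is the single entry point that syncs the lazy graph, transfers literals from the device, and rebuilds at::Tensor results. A minimal sketch of the resulting call path, assuming the XLAGraphExecutor::Get() singleton accessor used elsewhere in torch_xla; the MaterializeAll helper is hypothetical:

// Hypothetical helper: materialize lazy XLA tensors into host at::Tensors.
// After this commit the sync + transfer + fetch sequence lives directly in
// GetTensors, so callers need no separate fused entry point.
std::vector<at::Tensor> MaterializeAll(std::vector<XLATensorPtr>* tensors) {
  return XLAGraphExecutor::Get()->GetTensors(tensors);
}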
