diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 827d72ff6c7b..45ed57b1b04c 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -410,7 +410,40 @@ std::vector<at::Tensor> XLAGraphExecutor::GetTensors(
     std::vector<XLATensorPtr>* tensors) {
   TF_VLOG(4) << "Trying to get the value of " << tensors->size()
              << " tensor(s)";
-  return GetTensorsFused(tensors);
+  SyncTensorsConfig config;
+  config.force_ltc_data = false;
+  auto async = SyncTensorsGraphInternal(tensors, {}, config);
+  if (async != nullptr) {
+    async->mwait.Wait();
+  }
+  std::vector<torch::lazy::BackendDataPtr> tensors_data = GatherTensorsXlaData(
+      *tensors, async != nullptr ? async->indices : absl::Span<const size_t>(),
+      async != nullptr ? async->tensors_data
+                       : absl::Span<const torch::lazy::BackendDataPtr>());
+
+  // Execution is async in PJRT, so TransferFromServer may block until execution
+  // completes. Release the GIL so other threads can proceed and unblock any
+  // collective computations.
+  // HACK: This method may be called outside of python (mainly in C++ tests) or
+  // when the GIL is already released, so we must check both cases here. If
+  // possible, prefer to release the GIL in the python bindings before copying
+  // this pattern.
+  PyThreadState* save = nullptr;
+  // TODO(wcromar): Remove this setting when we are more confident
+  static const bool release_gil =
+      runtime::sys_util::GetEnvBool("XLA_RELEASE_GIL_DURING_TRANSFER", true);
+  if (release_gil && Py_IsInitialized() && PyGILState_Check()) {
+    save = PyEval_SaveThread();
+  }
+  std::vector<xla::Literal> literals =
+      runtime::GetComputationClient()->TransferFromServer(
+          UnwrapXlaData(tensors_data));
+  if (save) {
+    PyEval_RestoreThread(save);
+  }
+
+  return FetchTensors(tensors, literals,
+                      async != nullptr ? &async->indices : nullptr);
 }
 
 torch::lazy::hash_t XLAGraphExecutor::GetGraphHash(
@@ -798,44 +831,6 @@ std::vector<torch::lazy::BackendDataPtr> XLAGraphExecutor::ExecuteStablehlo(
   return WrapXlaData(result_data);
 }
 
-std::vector<at::Tensor> XLAGraphExecutor::GetTensorsFused(
-    std::vector<XLATensorPtr>* tensors) {
-  SyncTensorsConfig config;
-  config.force_ltc_data = false;
-  auto async = SyncTensorsGraphInternal(tensors, {}, config);
-  if (async != nullptr) {
-    async->mwait.Wait();
-  }
-  std::vector<torch::lazy::BackendDataPtr> tensors_data = GatherTensorsXlaData(
-      *tensors, async != nullptr ? async->indices : absl::Span<const size_t>(),
-      async != nullptr ? async->tensors_data
-                       : absl::Span<const torch::lazy::BackendDataPtr>());
-
-  // Execution is async in PJRT, so TransferFromServer may block until execution
-  // completes. Release the GIL so other threads can proceed and unblock any
-  // collective computations.
-  // HACK: This method may be called outside of python (mainly in C++ tests) or
-  // when the GIL is already released, so we must check both cases here. If
-  // possible, prefer to release the GIL in the python bindings before copying
-  // this pattern.
-  PyThreadState* save = nullptr;
-  // TODO(wcromar): Remove this setting when we are more confident
-  static const bool release_gil =
-      runtime::sys_util::GetEnvBool("XLA_RELEASE_GIL_DURING_TRANSFER", true);
-  if (release_gil && Py_IsInitialized() && PyGILState_Check()) {
-    save = PyEval_SaveThread();
-  }
-  std::vector<xla::Literal> literals =
-      runtime::GetComputationClient()->TransferFromServer(
-          UnwrapXlaData(tensors_data));
-  if (save) {
-    PyEval_RestoreThread(save);
-  }
-
-  return FetchTensors(tensors, literals,
-                      async != nullptr ? &async->indices : nullptr);
-}
-
 std::vector<torch::lazy::BackendDataPtr> XLAGraphExecutor::GatherTensorsXlaData(
     const std::vector<XLATensorPtr>& tensors, absl::Span<const size_t> indices,
     absl::Span<const torch::lazy::BackendDataPtr> tensors_data) {
diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h
index 0798853ecf0d..829f63a806a4 100644
--- a/torch_xla/csrc/xla_graph_executor.h
+++ b/torch_xla/csrc/xla_graph_executor.h
@@ -258,9 +258,6 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
   // Override to enable SPMD.
   void TensorCollectionBarrier(SyncTensorCollection* coll) final;
 
-  // We don't use upstream GetTensorsFused as we have xla::Literal.
-  std::vector<at::Tensor> GetTensorsFused(std::vector<XLATensorPtr>* tensors);
-
   // Gathers the XLA device data for all the input tensors, after an
   // asynchronous operation.
   // TODO(alanwaketan): Reuse the upstream one once Functionalization is done.
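
For reference, below is a minimal self-contained sketch of the GIL-release pattern that the inlined GetTensors body uses around the blocking transfer. The helper names (TransferWithGilReleased, BlockingTransfer) are illustrative only and are not part of this patch.

// Sketch of the save/restore GIL pattern; assumes the Python headers are
// available. BlockingTransfer() stands in for a blocking device-to-host
// transfer such as TransferFromServer().
#include <Python.h>

#include <chrono>
#include <thread>

static void BlockingTransfer() {
  // Placeholder for a transfer that may wait on async execution to finish.
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
}

void TransferWithGilReleased() {
  PyThreadState* save = nullptr;
  // Release the GIL only if the interpreter is initialized and this thread
  // currently holds it, so the same path also works from pure C++ tests.
  if (Py_IsInitialized() && PyGILState_Check()) {
    save = PyEval_SaveThread();  // drop the GIL; other Python threads may run
  }
  BlockingTransfer();
  if (save != nullptr) {
    PyEval_RestoreThread(save);  // re-acquire the GIL on this same thread
  }
}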