diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp
index 08beb49b7b2..77f2b0c5768 100644
--- a/test/cpp/test_xla_sharding.cpp
+++ b/test/cpp/test_xla_sharding.cpp
@@ -27,7 +27,8 @@ namespace {
 bool XlaDataValuesEqual(torch::lazy::BackendDataPtr a,
                         torch::lazy::BackendDataPtr b,
                         at::ScalarType element_type) {
-  std::vector<at::Tensor> tensors = XlaDataToTensors({a, b}, element_type);
+  std::vector<at::Tensor> tensors =
+      XlaDataToTensors({a, b}, {element_type, element_type});
   return TensorCompare(tensors[0], tensors[1]);
 }
 }  // namespace
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 4db190c76b8..2e1c5d00792 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1784,8 +1784,8 @@ void InitXlaModuleBindings(py::module m) {
              shard_handles) {
           shards.push_back(
               XlaDataToTensors({shard_handle},
-                               MaybeUpcastToHostTorchType(
-                                   shard_handle->shape().element_type()))
+                               {MaybeUpcastToHostTorchType(
+                                   shard_handle->shape().element_type())})
                   .front());
           str_devices.push_back(shard_handle->device());
         }
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 5d4d6825e03..f48d3a83db3 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -467,7 +467,8 @@ at::Tensor XLATensor::ToTensor(bool detached) {
     XLAGraphExecutor::Get()->DeviceBarrier(GetDevice());
     // The GetXlaData() call will trigger an ApplyPendingGraph() if an IR
     // XlaNode is available on the tensor.
-    std::vector<at::Tensor> tensors = XlaDataToTensors({GetXlaData()}, dtype());
+    std::vector<at::Tensor> tensors =
+        XlaDataToTensors({GetXlaData()}, {dtype()});
     tensor = std::move(tensors.front());
     if (!detached) {
       SetTensorData(tensor);
diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp
index e46bf7e022c..c2183f7e785 100644
--- a/torch_xla/csrc/tensor_util.cpp
+++ b/torch_xla/csrc/tensor_util.cpp
@@ -796,13 +796,18 @@ std::vector<xla::Literal> ReleaseGilAndTransferData(
 }
 
 std::vector<at::Tensor> XlaDataToTensors(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data,
-    at::ScalarType dest_element_type) {
+    absl::Span<const at::ScalarType> dest_element_type) {
   std::vector<xla::Literal> literals = ReleaseGilAndTransferData(xla_data);
-  std::vector<at::Tensor> tensors;
-  tensors.reserve(literals.size());
-  for (auto& literal : literals) {
-    tensors.push_back(MakeTensorFromXlaLiteral(literal, dest_element_type));
+  std::vector<at::Tensor> tensors(literals.size());
+  absl::BlockingCounter counter(literals.size());
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    auto copy_fn = [&, i]() {
+      tensors[i] = MakeTensorFromXlaLiteral(literals[i], dest_element_type[i]);
+      counter.DecrementCount();
+    };
+    thread::Schedule(std::move(copy_fn));
   }
+  counter.Wait();
   return tensors;
 }
diff --git a/torch_xla/csrc/tensor_util.h b/torch_xla/csrc/tensor_util.h
index 81b4cd9a565..f9ca29f7ab1 100644
--- a/torch_xla/csrc/tensor_util.h
+++ b/torch_xla/csrc/tensor_util.h
@@ -34,7 +34,7 @@ std::vector<xla::Literal> ReleaseGilAndTransferData(
 
 // TODO LTC @wonjoo - Migrate to upstream after Device -> BackendDevice
 std::vector<at::Tensor> XlaDataToTensors(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data,
-    at::ScalarType dest_element_type);
+    absl::Span<const at::ScalarType> dest_element_type);
 
 bool TensorCompare(const at::Tensor& t1, const at::Tensor& t2);
diff --git a/torch_xla/csrc/xla_backend_impl.cpp b/torch_xla/csrc/xla_backend_impl.cpp
index 4adb9f50eae..c2cb2f43289 100644
--- a/torch_xla/csrc/xla_backend_impl.cpp
+++ b/torch_xla/csrc/xla_backend_impl.cpp
@@ -93,7 +93,7 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
       const torch::lazy::BackendDataPtr data,
       c10::optional<at::ScalarType> logical_scalar_type) const override {
     // TODO(JackCaoG): handle the logical_scalar_type == nullptr case
-    return XlaDataToTensors({data}, *logical_scalar_type)[0];
+    return XlaDataToTensors({data}, {*logical_scalar_type})[0];
   }
 
   std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
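
Note on the new XlaDataToTensors body: instead of converting literals to tensors one by one, each conversion is handed to the backend's thread pool and the caller blocks on an absl::BlockingCounter until every output slot has been filled. The snippet below is a minimal, self-contained sketch of that fan-out/join pattern, not the actual torch_xla code: the Literal/Tensor/ScalarType stand-ins, the ConvertLiteral helper, and the detached-std::thread Schedule shim are placeholders for xla::Literal, at::Tensor, MakeTensorFromXlaLiteral, and the internal thread::Schedule used in the patch.

// Sketch of the parallel literal-to-tensor copy pattern (placeholder types).
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

#include "absl/synchronization/blocking_counter.h"
#include "absl/types/span.h"

struct Literal {};                       // stand-in for xla::Literal
struct Tensor {};                        // stand-in for at::Tensor
enum class ScalarType { Float, Long };   // stand-in for at::ScalarType

// Hypothetical conversion; the real code calls MakeTensorFromXlaLiteral.
Tensor ConvertLiteral(const Literal& literal, ScalarType dtype) {
  (void)literal;
  (void)dtype;
  return Tensor{};
}

// Hypothetical scheduler; torch_xla hands the closure to a shared thread
// pool via thread::Schedule rather than spawning a detached thread.
void Schedule(std::function<void()> fn) {
  std::thread(std::move(fn)).detach();
}

std::vector<Tensor> LiteralsToTensors(absl::Span<const Literal> literals,
                                      absl::Span<const ScalarType> dtypes) {
  std::vector<Tensor> tensors(literals.size());
  absl::BlockingCounter counter(literals.size());
  for (size_t i = 0; i < literals.size(); ++i) {
    // Each task writes only to its own slot, so no mutex is needed.
    Schedule([&, i]() {
      tensors[i] = ConvertLiteral(literals[i], dtypes[i]);
      counter.DecrementCount();
    });
  }
  counter.Wait();  // Block until every conversion has finished.
  return tensors;
}

Because counter.Wait() does not return until every task has called DecrementCount(), the references captured by the scheduled closures stay valid for the lifetime of the tasks, and the per-dtype indexing (dest_element_type[i] in the patch) is what forces every call site to pass one ScalarType per BackendDataPtr.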