From c8f7315f5a4591e63dbc3be3f1c2ec2289769009 Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Mon, 20 Nov 2023 12:03:39 +0000 Subject: [PATCH] Update for #5799 --- torch_xla/csrc/tensor_util.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index a5022895cf4..c2183f7e785 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -799,17 +799,15 @@ std::vector XlaDataToTensors( absl::Span dest_element_type) { std::vector literals = ReleaseGilAndTransferData(xla_data); std::vector tensors(literals.size()); - auto mwait = std::make_shared(tensors.size()); + absl::BlockingCounter counter(literals.size()); for (size_t i = 0; i < tensors.size(); ++i) { auto copy_fn = [&, i]() { tensors[i] = MakeTensorFromXlaLiteral(literals[i], dest_element_type[i]); + counter.DecrementCount(); }; - // Use an IO closure, since MakeTensorFromXlaLiteral may block on async - // copies for >2D tensors. - runtime::env::ScheduleIoClosure( - runtime::util::MultiWait::Completer(mwait, std::move(copy_fn))); + thread::Schedule(std::move(copy_fn)); } - mwait->Wait(); + counter.Wait(); return tensors; }