Skip to content

Commit

Permalink
Update for #5799
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 committed Nov 30, 2023
1 parent 55db8a2 commit c8f7315
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions torch_xla/csrc/tensor_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -799,17 +799,15 @@ std::vector<at::Tensor> XlaDataToTensors(
absl::Span<const at::ScalarType> dest_element_type) {
std::vector<xla::Literal> literals = ReleaseGilAndTransferData(xla_data);
std::vector<at::Tensor> tensors(literals.size());
auto mwait = std::make_shared<runtime::util::MultiWait>(tensors.size());
absl::BlockingCounter counter(literals.size());
for (size_t i = 0; i < tensors.size(); ++i) {
auto copy_fn = [&, i]() {
tensors[i] = MakeTensorFromXlaLiteral(literals[i], dest_element_type[i]);
counter.DecrementCount();
};
// Use an IO closure, since MakeTensorFromXlaLiteral may block on async
// copies for >2D tensors.
runtime::env::ScheduleIoClosure(
runtime::util::MultiWait::Completer(mwait, std::move(copy_fn)));
thread::Schedule(std::move(copy_fn));
}
mwait->Wait();
counter.Wait();
return tensors;
}

Expand Down

0 comments on commit c8f7315

Please sign in to comment.