Skip to content

Commit

Permalink
Use upstream XLA concurrency utilities (#5799)
Browse files Browse the repository at this point in the history
* Use TSL threadpool

* remove multiwait

* fix test build

* Move threadpool namespace

* formatting

* fix test build

* Use BlockingCounter
  • Loading branch information
will-cromar authored and golechwierowicz committed Jan 12, 2024
1 parent cb282d9 commit d0d8ba3
Show file tree
Hide file tree
Showing 17 changed files with 89 additions and 422 deletions.
4 changes: 2 additions & 2 deletions test/cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ ptxla_cc_test(
":torch_xla_test",
"//torch_xla/csrc/runtime:runtime",
"//torch_xla/csrc/runtime:debug_macros",
"//torch_xla/csrc/runtime:multi_wait",
"//torch_xla/csrc/runtime:thread_pool",
"//torch_xla/csrc:tensor",
"//torch_xla/csrc:thread_pool",
"@com_google_absl//absl/synchronization",
"@com_google_googletest//:gtest_main",
"@xla//xla:shape_util",
"@xla//xla/client:xla_builder",
Expand Down
12 changes: 6 additions & 6 deletions test/cpp/test_replication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@

#include <iostream>

#include "absl/synchronization/blocking_counter.h"
#include "test/cpp/cpp_test_util.h"
#include "test/cpp/torch_xla_test.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/multi_wait.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/thread_pool.h"
#include "torch_xla/csrc/torch_util.h"
#include "xla/client/xla_builder.h"
#include "xla/shape_util.h"
Expand Down Expand Up @@ -57,7 +57,7 @@ void TestSingleReplication(

std::vector<std::vector<torch_xla::runtime::ComputationClient::DataPtr>>
results(device_strings.size());
torch_xla::runtime::util::MultiWait mwait(device_strings.size());
absl::BlockingCounter counter(device_strings.size());
torch_xla::runtime::ComputationClient::ExecuteComputationOptions exec_options;
for (size_t i = 0; i < device_strings.size(); ++i) {
auto executor = [&, i]() {
Expand All @@ -68,11 +68,11 @@ void TestSingleReplication(
torch_xla::runtime::ComputationClient::Data>(
tensors_data[i])},
device_strings[i], exec_options);
counter.DecrementCount();
};
torch_xla::runtime::env::ScheduleIoClosure(
mwait.Completer(std::move(executor)));
torch_xla::thread::Schedule(std::move(executor));
}
mwait.Wait();
counter.Wait();

for (size_t i = 0; i < results.size(); ++i) {
auto literals =
Expand Down
13 changes: 11 additions & 2 deletions torch_xla/csrc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -269,15 +269,14 @@ ptxla_cc_library(
"//torch_xla/csrc/runtime:metrics",
"//torch_xla/csrc/runtime:metrics_analysis",
"//torch_xla/csrc/runtime:metrics_reader",
"//torch_xla/csrc/runtime:multi_wait",
"//torch_xla/csrc/runtime:profiler",
"//torch_xla/csrc/runtime:sys_util",
"//torch_xla/csrc/runtime:thread_pool",
"//torch_xla/csrc/runtime:util",
"//torch_xla/csrc/runtime:xla_coordinator",
"//torch_xla/csrc/runtime:xla_util",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:variant",
"@tsl//tsl/profiler/lib:traceme",
"@tsl//tsl/profiler/lib:traceme_encode",
Expand Down Expand Up @@ -320,6 +319,16 @@ cc_library(
],
)

cc_library(
name = "thread_pool",
srcs = ["thread_pool.cc"],
hdrs = ["thread_pool.h"],
deps = [
"//torch_xla/csrc/runtime:sys_util",
"@tsl//tsl/platform:env"
],
)

ptxla_cc_library(
name = "unwrap_data",
srcs = ["unwrap_data.cpp"],
Expand Down
3 changes: 1 addition & 2 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/types/variant.h"
#include "pybind11/attr.h"
#include "pybind11/cast.h"
Expand All @@ -43,11 +44,9 @@
#include "torch_xla/csrc/runtime/metrics.h"
#include "torch_xla/csrc/runtime/metrics_analysis.h"
#include "torch_xla/csrc/runtime/metrics_reader.h"
#include "torch_xla/csrc/runtime/multi_wait.h"
#include "torch_xla/csrc/runtime/profiler.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/runtime/util.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "torch_xla/csrc/runtime/xla_util.h"
Expand Down
23 changes: 2 additions & 21 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,12 @@ cc_library(
":computation_client",
":debug_macros",
":env_vars",
":multi_wait",
":profiler",
":stablehlo_helper",
":tensor_source",
":tf_logging",
":thread_pool",
":xla_coordinator",
"//torch_xla/csrc:thread_pool",
"@xla//xla:literal",
"@xla//xla:shape_util",
"@xla//xla/client:xla_computation",
Expand All @@ -102,6 +101,7 @@ cc_library(
"@tsl//tsl/profiler/lib:traceme",
"@tsl//tsl/platform/cloud:gcs_file_system",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
],
)
Expand Down Expand Up @@ -187,15 +187,6 @@ cc_library(
],
)

cc_library(
name = "multi_wait",
srcs = ["multi_wait.cc"],
hdrs = ["multi_wait.h"],
deps = [
"@xla//xla:types",
],
)

# Profiler silently fails unless we link these backends
cc_library(
name = "profiler_backends",
Expand Down Expand Up @@ -279,16 +270,6 @@ cc_library(
],
)

cc_library(
name = "thread_pool",
srcs = ["thread_pool.cc"],
hdrs = ["thread_pool.h"],
deps = [
":metrics",
":tf_logging",
],
)

cc_library(
name = "tensor_source",
hdrs = ["tensor_source.h"],
Expand Down
73 changes: 0 additions & 73 deletions torch_xla/csrc/runtime/multi_wait.cc

This file was deleted.

60 changes: 0 additions & 60 deletions torch_xla/csrc/runtime/multi_wait.h

This file was deleted.

31 changes: 14 additions & 17 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@
#include <vector>

#include "absl/strings/ascii.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/types/span.h"
#include "pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/multi_wait.h"
#include "torch_xla/csrc/runtime/profiler.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "torch_xla/csrc/thread_pool.h"
#include "tsl/profiler/lib/traceme.h"
#include "xla/client/xla_builder.h"
#include "xla/client/xla_computation.h"
Expand Down Expand Up @@ -620,9 +620,9 @@ PjRtComputationClient::ExecuteComputation(
}
CreateDataHandlesCounter()->AddValue(datas.size());

auto mwait = std::make_shared<util::MultiWait>(1);
auto lockfn = [&, this, device, returned_future = std::move(*returned_future),
timed]() mutable {
thread::Schedule(std::move([&, this, device,
returned_future = std::move(*returned_future),
timed]() mutable {
TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for "
<< device;
// Grab the shared lock and block the `WaitDeviceOps` until buffer is
Expand All @@ -643,9 +643,7 @@ PjRtComputationClient::ExecuteComputation(
timed.reset();
TF_VLOG(3) << "ExecuteComputation returned_future->OnReady finished";
});
};

env::ScheduleIoClosure(util::MultiWait::Completer(mwait, std::move(lockfn)));
}));

TF_VLOG(1) << "Returning " << datas.size() << " results";
return datas;
Expand All @@ -669,7 +667,7 @@ PjRtComputationClient::ExecuteReplicated(
XLA_CHECK(devices.size() == arguments.size())
<< "ExecuteReplicated over " << devices.size() << " devices, but "
<< arguments.size() << " arguments devices.";
auto mwait_argument = std::make_shared<util::MultiWait>(devices.size());
absl::BlockingCounter counter(devices.size());
std::vector<std::vector<xla::PjRtBuffer*>> argument_handles(devices.size());
{
tsl::profiler::TraceMe activity(
Expand All @@ -690,11 +688,11 @@ PjRtComputationClient::ExecuteReplicated(
buffers.push_back(pjrt_data->buffer.get());
}
argument_handles[i] = std::move(buffers);
counter.DecrementCount();
};
env::ScheduleIoClosure(util::MultiWait::Completer(
mwait_argument, std::move(buffer_converter)));
thread::Schedule(std::move(buffer_converter));
}
mwait_argument->Wait();
counter.Wait();
}

xla::ExecuteOptions execute_options;
Expand Down Expand Up @@ -749,9 +747,9 @@ PjRtComputationClient::ExecuteReplicated(
}
}

auto mwait = std::make_shared<util::MultiWait>(1);
auto lockfn = [&, this, returned_futures = std::move(*returned_futures),
timed]() mutable {
thread::Schedule(std::move([&, this,
returned_futures = std::move(*returned_futures),
timed]() mutable {
// Grab the shared lock and block the `WaitDeviceOps` until buffer is
// ready. Since this is the SPMD code path. There is no points to grab
// devices lock for every individual device.
Expand All @@ -772,8 +770,7 @@ PjRtComputationClient::ExecuteReplicated(
timed.reset();
TF_VLOG(3) << "ExecuteReplicated returned_future->OnReady finished";
});
};
env::ScheduleIoClosure(util::MultiWait::Completer(mwait, std::move(lockfn)));
}));

TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results "
<< "with dimensions [" << absl::StrJoin(dims, ",") << "].";
Expand Down
Loading

0 comments on commit d0d8ba3

Please sign in to comment.