Revert "Remove some unused code from csrc/runtime (#5785)"
This reverts commit 79557cc.
JackCaoG authored Nov 13, 2023
1 parent 56733fb commit da5391e
Showing 20 changed files with 703 additions and 28 deletions.
2 changes: 2 additions & 0 deletions torch_xla/csrc/BUILD
@@ -119,8 +119,10 @@ ptxla_cc_library(
":layout_manager",
":shape_builder",
":shape_helper",
"//torch_xla/csrc/runtime:async_task",
"//torch_xla/csrc/runtime",
"//torch_xla/csrc/runtime:stablehlo_helper",
"//torch_xla/csrc/runtime:unique",
"//torch_xla/csrc/runtime:xla_util",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
6 changes: 3 additions & 3 deletions torch_xla/csrc/debug_util.cpp
@@ -1,7 +1,6 @@
#include "torch_xla/csrc/debug_util.h"

#include <torch/csrc/lazy/core/hash.h>
#include <torch/csrc/lazy/core/unique.h>
#include <torch/csrc/lazy/python/python_util.h>

#include <fstream>
@@ -18,6 +17,7 @@
#include "torch_xla/csrc/ir_dump_util.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/unique.h"
#include "torch_xla/csrc/xla_graph_executor.h"

namespace torch_xla {
@@ -61,7 +61,7 @@ std::string DebugUtil::GetTensorsGraphHlo(
absl::Span<const XLATensorPtr> tensors, const std::vector<size_t>* indices,
bool dump_stablehlo) {
std::vector<torch::lazy::Value> root_values;
torch::lazy::Unique<torch::lazy::BackendDevice> unique_device;
runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
if (indices != nullptr) {
for (auto index : *indices) {
const XLATensorPtr& tensor = tensors[index];
@@ -91,7 +91,7 @@ std::string DebugUtil::GetTensorsGraphInfo(
std::vector<const torch::lazy::Node*> root_nodes;
std::vector<torch::lazy::Value> root_values;
std::vector<torch::lazy::hash_t> root_hashes;
torch::lazy::Unique<torch::lazy::BackendDevice> unique_device;
runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
if (indices != nullptr) {
for (auto index : *indices) {
const XLATensorPtr& tensor = tensors[index];
3 changes: 1 addition & 2 deletions torch_xla/csrc/layout_manager.cpp
@@ -41,8 +41,7 @@ class LayoutManager {

struct DimensionsHasher {
size_t operator()(const absl::Span<const int64_t>& dimensions) const {
return torch::lazy::HashReduce(torch::lazy::MHash(
std::vector<int64_t>({dimensions.begin(), dimensions.end()})));
return runtime::util::HashReduce(runtime::util::MHash(dimensions));
}
};

52 changes: 45 additions & 7 deletions torch_xla/csrc/runtime/BUILD
@@ -3,15 +3,30 @@ load(
"if_cuda_is_configured",
)

load(
"//bazel:rules_def.bzl",
"ptxla_cc_test",
)

licenses(["notice"]) # Apache 2.0

package(default_visibility = ["//visibility:public"])

cc_library(
name = "async_task",
hdrs = ["async_task.h"],
deps = [
":debug_macros",
":thread_pool",
"@com_google_absl//absl/types:optional",
],
)

cc_test(
name = "async_task_test",
size = "small",
srcs = ["async_task_test.cc"],
deps = [
":async_task",
"@com_google_googletest//:gtest_main",
],
)

cc_library(
name = "runtime",
srcs = [
@@ -187,6 +202,20 @@ cc_library(
],
)

cc_library(
name = "nccl_distributed",
srcs = ["nccl_distributed.cc"],
hdrs = ["nccl_distributed.h"],
deps = [
":debug_macros",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@xla//xla:types",
] + if_cuda_is_configured([
"@local_config_nccl//:nccl",
]),
)

cc_library(
name = "profiler",
srcs = ["profiler.cc"],
@@ -274,8 +303,18 @@ cc_library(
],
)

cc_library(
name = "unique",
hdrs = ["unique.h"],
deps = [
":debug_macros",
"@com_google_absl//absl/types:optional",
],
)

cc_library(
name = "util",
srcs = ["util.cc"],
hdrs = ["util.h"],
deps = [
":types",
@@ -318,11 +357,10 @@ cc_library(
"@xla//xla/service:platform_util",
"@xla//xla/service/spmd:spmd_partitioner",
"@tsl//tsl/platform:errors",
"@torch//:headers",
],
)

ptxla_cc_test(
cc_test(
name = "xla_util_test",
size = "small",
srcs = ["xla_util_test.cc"],
93 changes: 93 additions & 0 deletions torch_xla/csrc/runtime/async_task.h
@@ -0,0 +1,93 @@
#ifndef XLA_CLIENT_ASYNC_TASK_H_
#define XLA_CLIENT_ASYNC_TASK_H_

#include <condition_variable>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>

#include "absl/types/optional.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/thread_pool.h"

namespace torch_xla {
namespace runtime {
namespace util {

template <typename T>
class AsyncTask {
struct Data {
Data(std::function<T()> taskfn) : taskfn(std::move(taskfn)) {}

std::function<T()> taskfn;
std::mutex mutex;
std::condition_variable cv;
bool scheduled = false;
bool completed = false;
absl::optional<T> result;
std::exception_ptr exptr;
};

public:
explicit AsyncTask(std::function<T()> taskfn)
: data_(std::make_shared<Data>(std::move(taskfn))) {}

AsyncTask& Wait() {
std::unique_lock<std::mutex> lock(data_->mutex);
XLA_CHECK(data_->scheduled);
data_->cv.wait(lock, [this] { return data_->completed; });
if (data_->exptr != nullptr) {
std::rethrow_exception(data_->exptr);
}
return *this;
}

AsyncTask& Schedule() {
auto completer = [data = data_]() {
absl::optional<T> result;
std::exception_ptr exptr;
try {
result = data->taskfn();
} catch (...) {
exptr = std::current_exception();
}

std::lock_guard<std::mutex> lock(data->mutex);
if (result) {
data->result = std::move(*result);
} else {
data->exptr = std::move(exptr);
}
data->completed = true;
data->cv.notify_all();
};

{
std::lock_guard<std::mutex> lock(data_->mutex);
XLA_CHECK(!data_->scheduled);
data_->scheduled = true;
}
torch_xla::runtime::env::ScheduleIoClosure(std::move(completer));
return *this;
}

const T& GetValue() const {
std::lock_guard<std::mutex> lock(data_->mutex);
return *data_->result;
}

T ConsumeValue() {
std::lock_guard<std::mutex> lock(data_->mutex);
return std::move(*data_->result);
}

private:
std::shared_ptr<Data> data_;
};

} // namespace util
} // namespace runtime
} // namespace torch_xla

#endif // XLA_CLIENT_ASYNC_TASK_H_
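For orientation, here is a minimal usage sketch of the restored AsyncTask helper; the function name and values below are illustrative and not part of this commit:

#include "torch_xla/csrc/runtime/async_task.h"

// Hypothetical example: run a computation on the IO thread pool and
// block until its result is available.
int RunExampleTask() {
  torch_xla::runtime::util::AsyncTask<int> task([]() -> int {
    return 2 + 3;  // stand-in for real work
  });
  task.Schedule();             // enqueue via env::ScheduleIoClosure
  task.Wait();                 // rethrows any exception from the task body
  return task.ConsumeValue();  // moves the stored result out
}

The test file below exercises the same API, including exception propagation and the move-only result path.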
65 changes: 65 additions & 0 deletions torch_xla/csrc/runtime/async_task_test.cc
@@ -0,0 +1,65 @@
#include "torch_xla/csrc/runtime/async_task.h"

#include <gtest/gtest.h>

#include <stdexcept>

namespace torch_xla {
namespace runtime {

TEST(AsyncTaskTest, BaseTest) {
auto taskfn = []() -> int { return 17; };

torch_xla::runtime::util::AsyncTask<int> async(std::move(taskfn));
async.Schedule();
async.Wait();
EXPECT_EQ(async.GetValue(), 17);
}

TEST(AsyncTaskTest, ExceptionTest) {
auto taskfn = []() -> int { throw std::runtime_error("Task Exception"); };

torch_xla::runtime::util::AsyncTask<int> async(std::move(taskfn));
async.Schedule();
bool got_exception = false;
try {
async.Wait();
} catch (const std::exception&) {
got_exception = true;
}
EXPECT_TRUE(got_exception);
}

TEST(AsyncTaskTest, NoResultCopyTest) {
struct Result {
Result(int* counter) : counter(counter) {}
Result(const Result& ref) : counter(ref.counter) { ++(*counter); }

Result& operator=(const Result& ref) {
if (this != &ref) {
counter = ref.counter;
++(*counter);
}
return *this;
}

Result(Result&&) = default;
Result& operator=(Result&&) = default;

int* counter = nullptr;
};

int copy_counter = 0;
auto taskfn = [&]() -> Result { return Result(&copy_counter); };

torch_xla::runtime::util::AsyncTask<Result> async(std::move(taskfn));
async.Schedule();
async.Wait();

Result result = async.ConsumeValue();
EXPECT_EQ(copy_counter, 0);
EXPECT_EQ(result.counter, &copy_counter);
}

} // namespace runtime
} // namespace torch_xla
71 changes: 71 additions & 0 deletions torch_xla/csrc/runtime/nccl_distributed.cc
@@ -0,0 +1,71 @@
#include "torch_xla/csrc/runtime/nccl_distributed.h"

#include <map>
#include <mutex>

#include "absl/strings/str_join.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#if XLA_CUDA
#include "third_party/nccl/nccl.h"
#endif

namespace torch_xla {
namespace runtime {
namespace nccl_detail {

#if XLA_CUDA

namespace {

class NcclUidManager {
public:
static NcclUidManager* Get();

std::string GetNcclUniqueUid(absl::Span<const int64_t> replicas);

private:
std::mutex mutex_;
std::map<std::string, std::string> replicas_uid_map_;
};

NcclUidManager* NcclUidManager::Get() {
static NcclUidManager* nccl_mgr = new NcclUidManager();
return nccl_mgr;
}

std::string NcclUidManager::GetNcclUniqueUid(
absl::Span<const int64_t> replicas) {
std::string replicas_str = absl::StrJoin(replicas, ",");
std::lock_guard<std::mutex> lock(mutex_);
auto it = replicas_uid_map_.find(replicas_str);
if (it == replicas_uid_map_.end()) {
ncclUniqueId id;
ncclResult_t r = ncclGetUniqueId(&id);
XLA_CHECK_EQ(r, ncclSuccess)
<< "NCCL UID generation failed: replicas=(" << replicas_str
<< "), error: " << ncclGetErrorString(r);
it = replicas_uid_map_
.emplace(std::move(replicas_str),
std::string(id.internal, NCCL_UNIQUE_ID_BYTES))
.first;
}
return it->second;
}

} // namespace

std::string GetNcclUniqueUid(absl::Span<const int64_t> replicas) {
return NcclUidManager::Get()->GetNcclUniqueUid(replicas);
}

#else // XLA_CUDA

std::string GetNcclUniqueUid(absl::Span<const int64_t> replicas) {
XLA_ERROR() << "Calling GetNcclUniqueUid() without NCCL configuration";
}

#endif // XLA_CUDA

} // namespace nccl_detail
} // namespace runtime
} // namespace torch_xla
19 changes: 19 additions & 0 deletions torch_xla/csrc/runtime/nccl_distributed.h
@@ -0,0 +1,19 @@
#ifndef XLA_CLIENT_NCCL_DISTRIBUTED_H_
#define XLA_CLIENT_NCCL_DISTRIBUTED_H_

#include <string>

#include "absl/types/span.h"
#include "xla/types.h"

namespace torch_xla {
namespace runtime {
namespace nccl_detail {

std::string GetNcclUniqueUid(absl::Span<const int64_t> replicas);

} // namespace nccl_detail
} // namespace runtime
} // namespace torch_xla

#endif // XLA_CLIENT_NCCL_DISTRIBUTED_H_
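As a rough sketch of how the restored helper is called (the replica group below is hypothetical): the same replica list always yields the same cached uid, and without CUDA/NCCL configured the call fails via XLA_ERROR():

#include <cstdint>
#include <string>
#include <vector>

#include "torch_xla/csrc/runtime/nccl_distributed.h"

// Hypothetical example: fetch the NCCL unique id shared by a replica group.
std::string UidForReplicaGroup() {
  std::vector<int64_t> replicas = {0, 1, 2, 3};  // illustrative group
  // absl::Span<const int64_t> is implicitly constructible from the vector.
  return torch_xla::runtime::nccl_detail::GetNcclUniqueUid(replicas);
}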
