-
Notifications
You must be signed in to change notification settings - Fork 480
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Revert "Remove some unused code from
csrc/runtime
(#5785)"
This reverts commit 79557cc.
- Loading branch information
Showing
20 changed files
with
703 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
#ifndef XLA_CLIENT_ASYNC_TASK_H_ | ||
#define XLA_CLIENT_ASYNC_TASK_H_ | ||
|
||
#include <condition_variable> | ||
#include <exception> | ||
#include <functional> | ||
#include <memory> | ||
#include <mutex> | ||
|
||
#include "absl/types/optional.h" | ||
#include "torch_xla/csrc/runtime/debug_macros.h" | ||
#include "torch_xla/csrc/runtime/thread_pool.h" | ||
|
||
namespace torch_xla { | ||
namespace runtime { | ||
namespace util { | ||
|
||
template <typename T> | ||
class AsyncTask { | ||
struct Data { | ||
Data(std::function<T()> taskfn) : taskfn(std::move(taskfn)) {} | ||
|
||
std::function<T()> taskfn; | ||
std::mutex mutex; | ||
std::condition_variable cv; | ||
bool scheduled = false; | ||
bool completed = false; | ||
absl::optional<T> result; | ||
std::exception_ptr exptr; | ||
}; | ||
|
||
public: | ||
explicit AsyncTask(std::function<T()> taskfn) | ||
: data_(std::make_shared<Data>(std::move(taskfn))) {} | ||
|
||
AsyncTask& Wait() { | ||
std::unique_lock<std::mutex> lock(data_->mutex); | ||
XLA_CHECK(data_->scheduled); | ||
data_->cv.wait(lock, [this] { return data_->completed; }); | ||
if (data_->exptr != nullptr) { | ||
std::rethrow_exception(data_->exptr); | ||
} | ||
return *this; | ||
} | ||
|
||
AsyncTask& Schedule() { | ||
auto completer = [data = data_]() { | ||
absl::optional<T> result; | ||
std::exception_ptr exptr; | ||
try { | ||
result = data->taskfn(); | ||
} catch (...) { | ||
exptr = std::current_exception(); | ||
} | ||
|
||
std::lock_guard<std::mutex> lock(data->mutex); | ||
if (result) { | ||
data->result = std::move(*result); | ||
} else { | ||
data->exptr = std::move(exptr); | ||
} | ||
data->completed = true; | ||
data->cv.notify_all(); | ||
}; | ||
|
||
{ | ||
std::lock_guard<std::mutex> lock(data_->mutex); | ||
XLA_CHECK(!data_->scheduled); | ||
data_->scheduled = true; | ||
} | ||
torch_xla::runtime::env::ScheduleIoClosure(std::move(completer)); | ||
return *this; | ||
} | ||
|
||
const T& GetValue() const { | ||
std::lock_guard<std::mutex> lock(data_->mutex); | ||
return *data_->result; | ||
} | ||
|
||
T ConsumeValue() { | ||
std::lock_guard<std::mutex> lock(data_->mutex); | ||
return std::move(*data_->result); | ||
} | ||
|
||
private: | ||
std::shared_ptr<Data> data_; | ||
}; | ||
|
||
} // namespace util | ||
} // namespace runtime | ||
} // namespace torch_xla | ||
|
||
#endif // XLA_CLIENT_ASYNC_TASK_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
#include "torch_xla/csrc/runtime/async_task.h" | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <stdexcept> | ||
|
||
namespace torch_xla { | ||
namespace runtime { | ||
|
||
TEST(AsyncTaskTest, BaseTest) { | ||
auto taskfn = []() -> int { return 17; }; | ||
|
||
torch_xla::runtime::util::AsyncTask<int> async(std::move(taskfn)); | ||
async.Schedule(); | ||
async.Wait(); | ||
EXPECT_EQ(async.GetValue(), 17); | ||
} | ||
|
||
TEST(AsyncTaskTest, ExceptionTest) { | ||
auto taskfn = []() -> int { throw std::runtime_error("Task Exception"); }; | ||
|
||
torch_xla::runtime::util::AsyncTask<int> async(std::move(taskfn)); | ||
async.Schedule(); | ||
bool got_exception = false; | ||
try { | ||
async.Wait(); | ||
} catch (const std::exception&) { | ||
got_exception = true; | ||
} | ||
EXPECT_TRUE(got_exception); | ||
} | ||
|
||
TEST(AsyncTaskTest, NoResultCopyTest) { | ||
struct Result { | ||
Result(int* counter) : counter(counter) {} | ||
Result(const Result& ref) : counter(ref.counter) { ++(*counter); } | ||
|
||
Result& operator=(const Result& ref) { | ||
if (this != &ref) { | ||
counter = ref.counter; | ||
++(*counter); | ||
} | ||
return *this; | ||
} | ||
|
||
Result(Result&&) = default; | ||
Result& operator=(Result&&) = default; | ||
|
||
int* counter = nullptr; | ||
}; | ||
|
||
int copy_counter = 0; | ||
auto taskfn = [&]() -> Result { return Result(©_counter); }; | ||
|
||
torch_xla::runtime::util::AsyncTask<Result> async(std::move(taskfn)); | ||
async.Schedule(); | ||
async.Wait(); | ||
|
||
Result result = async.ConsumeValue(); | ||
EXPECT_EQ(copy_counter, 0); | ||
EXPECT_EQ(result.counter, ©_counter); | ||
} | ||
|
||
} // namespace runtime | ||
} // namespace torch_xla |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
#include "torch_xla/csrc/runtime/nccl_distributed.h" | ||
|
||
#include <map> | ||
#include <mutex> | ||
|
||
#include "absl/strings/str_join.h" | ||
#include "torch_xla/csrc/runtime/debug_macros.h" | ||
#if XLA_CUDA | ||
#include "third_party/nccl/nccl.h" | ||
#endif | ||
|
||
namespace torch_xla { | ||
namespace runtime { | ||
namespace nccl_detail { | ||
|
||
#if XLA_CUDA | ||
|
||
namespace { | ||
|
||
class NcclUidManager { | ||
public: | ||
static NcclUidManager* Get(); | ||
|
||
std::string GetNcclUniqueUid(absl::Span<const int64_t> replicas); | ||
|
||
private: | ||
std::mutex mutex_; | ||
std::map<std::string, std::string> replicas_uid_map_; | ||
}; | ||
|
||
NcclUidManager* NcclUidManager::Get() { | ||
static NcclUidManager* nccl_mgr = new NcclUidManager(); | ||
return nccl_mgr; | ||
} | ||
|
||
std::string NcclUidManager::GetNcclUniqueUid( | ||
absl::Span<const int64_t> replicas) { | ||
std::string replicas_str = absl::StrJoin(replicas, ","); | ||
std::lock_guard<std::mutex> lock(mutex_); | ||
auto it = replicas_uid_map_.find(replicas_str); | ||
if (it == replicas_uid_map_.end()) { | ||
ncclUniqueId id; | ||
ncclResult_t r = ncclGetUniqueId(&id); | ||
XLA_CHECK_EQ(r, ncclSuccess) | ||
<< "NCCL UID generation failed: replicas=(" << replicas_str | ||
<< "), error: " << ncclGetErrorString(r); | ||
it = replicas_uid_map_ | ||
.emplace(std::move(replicas_str), | ||
std::string(id.internal, NCCL_UNIQUE_ID_BYTES)) | ||
.first; | ||
} | ||
return it->second; | ||
} | ||
|
||
} // namespace | ||
|
||
std::string GetNcclUniqueUid(absl::Span<const int64_t> replicas) { | ||
return NcclUidManager::Get()->GetNcclUniqueUid(replicas); | ||
} | ||
|
||
#else // XLA_CUDA | ||
|
||
std::string GetNcclUniqueUid(absl::Span<const int64_t> replicas) { | ||
XLA_ERROR() << "Calling GetNcclUniqueUid() without NCCL configuration"; | ||
} | ||
|
||
#endif // XLA_CUDA | ||
|
||
} // namespace nccl_detail | ||
} // namespace runtime | ||
} // namespace torch_xla |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#ifndef XLA_CLIENT_NCCL_DISTRIBUTED_H_ | ||
#define XLA_CLIENT_NCCL_DISTRIBUTED_H_ | ||
|
||
#include <string> | ||
|
||
#include "absl/types/span.h" | ||
#include "xla/types.h" | ||
|
||
namespace torch_xla { | ||
namespace runtime { | ||
namespace nccl_detail { | ||
|
||
std::string GetNcclUniqueUid(absl::Span<const int64_t> replicas); | ||
|
||
} // namespace nccl_detail | ||
} // namespace runtime | ||
} // namespace torch_xla | ||
|
||
#endif // XLA_CLIENT_NCCL_DISTRIBUTED_H_ |
Oops, something went wrong.