From 5b6abde3d647719581cd14cf2c7e996373bc305e Mon Sep 17 00:00:00 2001
From: Will Cromar <wcromar@google.com>
Date: Thu, 9 Nov 2023 19:17:54 +0000
Subject: [PATCH 1/9] delete nccl_distributed

---
 torch_xla/csrc/runtime/BUILD               | 14 -----
 torch_xla/csrc/runtime/nccl_distributed.cc | 71 ----------------------
 torch_xla/csrc/runtime/nccl_distributed.h  | 19 ------
 3 files changed, 104 deletions(-)
 delete mode 100644 torch_xla/csrc/runtime/nccl_distributed.cc
 delete mode 100644 torch_xla/csrc/runtime/nccl_distributed.h

diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD
index da12492d2cd..39fb8bd6420 100644
--- a/torch_xla/csrc/runtime/BUILD
+++ b/torch_xla/csrc/runtime/BUILD
@@ -201,20 +201,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "nccl_distributed",
-    srcs = ["nccl_distributed.cc"],
-    hdrs = ["nccl_distributed.h"],
-    deps = [
-        ":debug_macros",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@xla//xla:types",
-    ] + if_cuda_is_configured([
-        "@local_config_nccl//:nccl",
-    ]),
-)
-
 cc_library(
     name = "profiler",
     srcs = ["profiler.cc"],
diff --git a/torch_xla/csrc/runtime/nccl_distributed.cc b/torch_xla/csrc/runtime/nccl_distributed.cc
deleted file mode 100644
index 51088913b88..00000000000
--- a/torch_xla/csrc/runtime/nccl_distributed.cc
+++ /dev/null
@@ -1,71 +0,0 @@
-#include "torch_xla/csrc/runtime/nccl_distributed.h"
-
-#include <map>
-#include <mutex>
-
-#include "absl/strings/str_join.h"
-#include "torch_xla/csrc/runtime/debug_macros.h"
-#if XLA_CUDA
-#include "third_party/nccl/nccl.h"
-#endif
-
-namespace torch_xla {
-namespace runtime {
-namespace nccl_detail {
-
-#if XLA_CUDA
-
-namespace {
-
-class NcclUidManager {
- public:
-  static NcclUidManager* Get();
-
-  std::string GetNcclUniqueUid(absl::Span<const int64_t> replicas);
-
- private:
-  std::mutex mutex_;
-  std::map<std::string, std::string> replicas_uid_map_;
-};
-
-NcclUidManager* NcclUidManager::Get() {
-  static NcclUidManager* nccl_mgr = new NcclUidManager();
-  return nccl_mgr;
-}
-
-std::string NcclUidManager::GetNcclUniqueUid(
-    absl::Span<const int64_t> replicas) {
-  std::string replicas_str = absl::StrJoin(replicas, ",");
-  std::lock_guard<std::mutex> lock(mutex_);
-  auto it = replicas_uid_map_.find(replicas_str);
-  if (it == replicas_uid_map_.end()) {
-    ncclUniqueId id;
-    ncclResult_t r = ncclGetUniqueId(&id);
-    XLA_CHECK_EQ(r, ncclSuccess)
-        << "NCCL UID generation failed: replicas=(" << replicas_str
-        << "), error: " << ncclGetErrorString(r);
-    it = replicas_uid_map_
-             .emplace(std::move(replicas_str),
-                      std::string(id.internal, NCCL_UNIQUE_ID_BYTES))
-             .first;
-  }
-  return it->second;
-}
-
-}  // namespace
-
-std::string GetNcclUniqueUid(absl::Span<const int64_t> replicas) {
-  return NcclUidManager::Get()->GetNcclUniqueUid(replicas);
-}
-
-#else  // XLA_CUDA
-
-std::string GetNcclUniqueUid(absl::Span<const int64_t> replicas) {
-  XLA_ERROR() << "Calling GetNcclUniqueUid() without NCCL configuration";
-}
-
-#endif  // XLA_CUDA
-
-}  // namespace nccl_detail
-}  // namespace runtime
-}  // namespace torch_xla
diff --git a/torch_xla/csrc/runtime/nccl_distributed.h b/torch_xla/csrc/runtime/nccl_distributed.h
deleted file mode 100644
index de5e0b0887d..00000000000
--- a/torch_xla/csrc/runtime/nccl_distributed.h
+++ /dev/null
@@ -1,19 +0,0 @@
-#ifndef XLA_CLIENT_NCCL_DISTRIBUTED_H_
-#define XLA_CLIENT_NCCL_DISTRIBUTED_H_
-
-#include <string>
-
-#include "absl/types/span.h"
-#include "xla/types.h"
-
-namespace torch_xla {
-namespace runtime {
-namespace nccl_detail {
-
-std::string GetNcclUniqueUid(absl::Span<const int64_t> replicas);
-
-}  // namespace nccl_detail
-}  // namespace runtime
-}  // namespace torch_xla
-
-#endif  // XLA_CLIENT_NCCL_DISTRIBUTED_H_

From 460cf120bae4d83cb0a4ac108a3794093226ebe9 Mon Sep 17 00:00:00 2001
From: Will Cromar <wcromar@google.com>
Date: Thu, 9 Nov 2023 19:30:19 +0000
Subject: [PATCH 2/9] remove async_task

---
 torch_xla/csrc/runtime/BUILD              | 20 -----
 torch_xla/csrc/runtime/async_task.h       | 93 -----------------------
 torch_xla/csrc/runtime/async_task_test.cc | 65 ----------------
 torch_xla/csrc/xla_graph_executor.h       |  1 -
 4 files changed, 179 deletions(-)
 delete mode 100644 torch_xla/csrc/runtime/async_task.h
 delete mode 100644 torch_xla/csrc/runtime/async_task_test.cc

diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD
index 39fb8bd6420..9b549312bba 100644
--- a/torch_xla/csrc/runtime/BUILD
+++ b/torch_xla/csrc/runtime/BUILD
@@ -7,26 +7,6 @@ licenses(["notice"])  # Apache 2.0
 
 package(default_visibility = ["//visibility:public"])
 
-cc_library(
-    name = "async_task",
-    hdrs = ["async_task.h"],
-    deps = [
-        ":debug_macros",
-        ":thread_pool",
-        "@com_google_absl//absl/types:optional",
-    ],
-)
-
-cc_test(
-    name = "async_task_test",
-    size = "small",
-    srcs = ["async_task_test.cc"],
-    deps = [
-        ":async_task",
-        "@com_google_googletest//:gtest_main",
-    ],
-)
-
 cc_library(
     name = "runtime",
     srcs = [
diff --git a/torch_xla/csrc/runtime/async_task.h b/torch_xla/csrc/runtime/async_task.h
deleted file mode 100644
index 73d923e0eb2..00000000000
--- a/torch_xla/csrc/runtime/async_task.h
+++ /dev/null
@@ -1,93 +0,0 @@
-#ifndef XLA_CLIENT_ASYNC_TASK_H_
-#define XLA_CLIENT_ASYNC_TASK_H_
-
-#include <condition_variable>
-#include <exception>
-#include <functional>
-#include <memory>
-#include <mutex>
-
-#include "absl/types/optional.h"
-#include "torch_xla/csrc/runtime/debug_macros.h"
-#include "torch_xla/csrc/runtime/thread_pool.h"
-
-namespace torch_xla {
-namespace runtime {
-namespace util {
-
-template <typename T>
-class AsyncTask {
-  struct Data {
-    Data(std::function<T()> taskfn) : taskfn(std::move(taskfn)) {}
-
-    std::function<T()> taskfn;
-    std::mutex mutex;
-    std::condition_variable cv;
-    bool scheduled = false;
-    bool completed = false;
-    absl::optional<T> result;
-    std::exception_ptr exptr;
-  };
-
- public:
-  explicit AsyncTask(std::function<T()> taskfn)
-      : data_(std::make_shared<Data>(std::move(taskfn))) {}
-
-  AsyncTask& Wait() {
-    std::unique_lock<std::mutex> lock(data_->mutex);
-    XLA_CHECK(data_->scheduled);
-    data_->cv.wait(lock, [this] { return data_->completed; });
-    if (data_->exptr != nullptr) {
-      std::rethrow_exception(data_->exptr);
-    }
-    return *this;
-  }
-
-  AsyncTask& Schedule() {
-    auto completer = [data = data_]() {
-      absl::optional<T> result;
-      std::exception_ptr exptr;
-      try {
-        result = data->taskfn();
-      } catch (...) {
-        exptr = std::current_exception();
-      }
-
-      std::lock_guard<std::mutex> lock(data->mutex);
-      if (result) {
-        data->result = std::move(*result);
-      } else {
-        data->exptr = std::move(exptr);
-      }
-      data->completed = true;
-      data->cv.notify_all();
-    };
-
-    {
-      std::lock_guard<std::mutex> lock(data_->mutex);
-      XLA_CHECK(!data_->scheduled);
-      data_->scheduled = true;
-    }
-    torch_xla::runtime::env::ScheduleIoClosure(std::move(completer));
-    return *this;
-  }
-
-  const T& GetValue() const {
-    std::lock_guard<std::mutex> lock(data_->mutex);
-    return *data_->result;
-  }
-
-  T ConsumeValue() {
-    std::lock_guard<std::mutex> lock(data_->mutex);
-    return std::move(*data_->result);
-  }
-
- private:
-  std::shared_ptr<Data> data_;
-};
-
-}  // namespace util
-}  // namespace runtime
-}  // namespace torch_xla
-
-#endif  // XLA_CLIENT_ASYNC_TASK_H_
diff --git a/torch_xla/csrc/runtime/async_task_test.cc b/torch_xla/csrc/runtime/async_task_test.cc
deleted file mode 100644
index 9b7a98c5a1f..00000000000
--- a/torch_xla/csrc/runtime/async_task_test.cc
+++ /dev/null
@@ -1,65 +0,0 @@
-#include "torch_xla/csrc/runtime/async_task.h"
-
-#include <gtest/gtest.h>
-
-#include <stdexcept>
-
-namespace torch_xla {
-namespace runtime {
-
-TEST(AsyncTaskTest, BaseTest) {
-  auto taskfn = []() -> int { return 17; };
-
-  torch_xla::runtime::util::AsyncTask<int> async(std::move(taskfn));
-  async.Schedule();
-  async.Wait();
-  EXPECT_EQ(async.GetValue(), 17);
-}
-
-TEST(AsyncTaskTest, ExceptionTest) {
-  auto taskfn = []() -> int { throw std::runtime_error("Task Exception"); };
-
-  torch_xla::runtime::util::AsyncTask<int> async(std::move(taskfn));
-  async.Schedule();
-  bool got_exception = false;
-  try {
-    async.Wait();
-  } catch (const std::exception&) {
-    got_exception = true;
-  }
-  EXPECT_TRUE(got_exception);
-}
-
-TEST(AsyncTaskTest, NoResultCopyTest) {
-  struct Result {
-    Result(int* counter) : counter(counter) {}
-    Result(const Result& ref) : counter(ref.counter) { ++(*counter); }
-
-    Result& operator=(const Result& ref) {
-      if (this != &ref) {
-        counter = ref.counter;
-        ++(*counter);
-      }
-      return *this;
-    }
-
-    Result(Result&&) = default;
-    Result& operator=(Result&&) = default;
-
-    int* counter = nullptr;
-  };
-
-  int copy_counter = 0;
-  auto taskfn = [&]() -> Result { return Result(&copy_counter); };
-
-  torch_xla::runtime::util::AsyncTask<Result> async(std::move(taskfn));
-  async.Schedule();
-  async.Wait();
-
-  Result result = async.ConsumeValue();
-  EXPECT_EQ(copy_counter, 0);
-  EXPECT_EQ(result.counter, &copy_counter);
-}
-
-}  // namespace runtime
-}  // namespace torch_xla
diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h
index 829f63a806a..c7a870be319 100644
--- a/torch_xla/csrc/xla_graph_executor.h
+++ b/torch_xla/csrc/xla_graph_executor.h
@@ -16,7 +16,6 @@
 #include "torch_xla/csrc/ir.h"
 #include "torch_xla/csrc/ir_dump_util.h"
 #include "torch_xla/csrc/lowering_context.h"
-#include "torch_xla/csrc/runtime/async_task.h"
 #include "torch_xla/csrc/runtime/cache.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/multi_wait.h"

From ebe4567b07b5fca72c6952b8ae30d4f404d3c8e2 Mon Sep 17 00:00:00 2001
From: Will Cromar <wcromar@google.com>
Date: Thu, 9 Nov 2023 19:31:58 +0000
Subject: [PATCH 3/9] remove unique

---
 torch_xla/csrc/BUILD                  |  2 --
 torch_xla/csrc/debug_util.cpp         | 32 ++++++++---------
 torch_xla/csrc/runtime/BUILD          |  9 -----
 torch_xla/csrc/runtime/unique.h       | 50 ---------------------------
 torch_xla/csrc/tensor.cpp             |  2 +-
 torch_xla/csrc/xla_graph_executor.cpp | 10 +++---
 6 files changed, 21 insertions(+), 84 deletions(-)
 delete mode 100644 torch_xla/csrc/runtime/unique.h
diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD
index d4eed1ec83f..8a38c87cd00 100644
--- a/torch_xla/csrc/BUILD
+++ b/torch_xla/csrc/BUILD
@@ -118,10 +118,8 @@ ptxla_cc_library(
         ":layout_manager",
         ":shape_builder",
         ":shape_helper",
-        "//torch_xla/csrc/runtime:async_task",
         "//torch_xla/csrc/runtime",
         "//torch_xla/csrc/runtime:stablehlo_helper",
-        "//torch_xla/csrc/runtime:unique",
         "//torch_xla/csrc/runtime:xla_util",
         "@com_google_absl//absl/hash",
         "@com_google_absl//absl/memory",
diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp
index 7723d6d95d9..e601034790b 100644
--- a/torch_xla/csrc/debug_util.cpp
+++ b/torch_xla/csrc/debug_util.cpp
@@ -6,6 +6,7 @@
 #include <fstream>
 #include <iostream>
 #include <mutex>
+#include <optional>
 #include <sstream>
 #include <unordered_set>
 
@@ -61,14 +61,14 @@ std::string DebugUtil::GetTensorsGraphHlo(
     absl::Span<const XLATensorPtr> tensors, const std::vector<size_t>* indices,
     bool dump_stablehlo) {
   std::vector<torch::lazy::Value> root_values;
-  runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
+  std::optional<torch::lazy::BackendDevice> device;
   if (indices != nullptr) {
     for (auto index : *indices) {
       const XLATensorPtr& tensor = tensors[index];
       torch::lazy::Value ir_value = tensor->CurrentIrValue();
       if (ir_value) {
         root_values.push_back(std::move(ir_value));
-        unique_device.set(tensor->GetDevice());
+        device = tensor->GetDevice();
       }
     }
   } else {
@@ -76,13 +76,13 @@
       torch::lazy::Value ir_value = tensor->CurrentIrValue();
       if (ir_value) {
         root_values.push_back(std::move(ir_value));
-        unique_device.set(tensor->GetDevice());
+        device = tensor->GetDevice();
       }
     }
   }
-  return DumpUtil::ToHlo(
-      root_values, unique_device ? *unique_device : bridge::GetCurrentDevice(),
-      EmitMode::kStableHloReadable);
+  return DumpUtil::ToHlo(root_values,
+                         device.value_or(bridge::GetCurrentDevice()),
+                         EmitMode::kStableHloReadable);
 }
 
 std::string DebugUtil::GetTensorsGraphInfo(
@@ -91,7 +91,7 @@
   std::vector<const torch::lazy::Node*> root_nodes;
   std::vector<torch::lazy::Value> root_values;
   std::vector<torch::lazy::hash_t> root_hashes;
-  runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
+  std::optional<torch::lazy::BackendDevice> device;
   if (indices != nullptr) {
     for (auto index : *indices) {
      const XLATensorPtr& tensor = tensors[index];
@@ -100,7 +100,7 @@
         root_nodes.push_back(ir_value.node.get());
         root_hashes.push_back(ir_value.hash());
         root_values.push_back(std::move(ir_value));
-        unique_device.set(tensor->GetDevice());
+        device = tensor->GetDevice();
       }
     }
   } else {
@@ -110,7 +110,7 @@
         root_nodes.push_back(ir_value.node.get());
         root_hashes.push_back(ir_value.hash());
         root_values.push_back(std::move(ir_value));
-        unique_device.set(tensor->GetDevice());
+        device = tensor->GetDevice();
       }
     }
   }
@@ -137,14 +137,12 @@
   } else if (format == GraphFormat::kDot) {
     graph_str = DumpUtil::ToDot(root_nodes);
   } else if (format == GraphFormat::kHlo) {
-    graph_str = DumpUtil::ToHlo(root_values, unique_device
-                                                 ? *unique_device
-                                                 : bridge::GetCurrentDevice());
+    graph_str = DumpUtil::ToHlo(root_values,
+                                device.value_or(bridge::GetCurrentDevice()));
   } else if (format == GraphFormat::kStableHlo) {
-    graph_str = DumpUtil::ToHlo(
-        root_values,
-        unique_device ? *unique_device : bridge::GetCurrentDevice(),
-        EmitMode::kStableHloReadable);
+    graph_str = DumpUtil::ToHlo(root_values,
+                                device.value_or(bridge::GetCurrentDevice()),
+                                EmitMode::kStableHloReadable);
   } else {
     XLA_ERROR() << "Invalid graph format: " << format;
   }
diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD
index 9b549312bba..2878f1c4b1d 100644
--- a/torch_xla/csrc/runtime/BUILD
+++ b/torch_xla/csrc/runtime/BUILD
@@ -268,15 +268,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "unique",
-    hdrs = ["unique.h"],
-    deps = [
-        ":debug_macros",
-        "@com_google_absl//absl/types:optional",
-    ],
-)
-
 cc_library(
     name = "util",
     srcs = ["util.cc"],
diff --git a/torch_xla/csrc/runtime/unique.h b/torch_xla/csrc/runtime/unique.h
deleted file mode 100644
index f50e24320d9..00000000000
--- a/torch_xla/csrc/runtime/unique.h
+++ /dev/null
@@ -1,50 +0,0 @@
-#ifndef XLA_CLIENT_UNIQUE_H_
-#define XLA_CLIENT_UNIQUE_H_
-
-#include <functional>
-#include <set>
-
-#include "absl/types/optional.h"
-#include "torch_xla/csrc/runtime/debug_macros.h"
-
-namespace torch_xla {
-namespace runtime {
-namespace util {
-
-// Helper class to allow tracking zero or more things, which should be forcibly
-// be one only thing.
-template <typename T, typename C = std::equal_to<T>>
-class Unique {
- public:
-  std::pair<bool, const T&> set(const T& value) {
-    if (value_) {
-      XLA_CHECK(C()(*value_, value))
-          << "'" << *value_ << "' vs '" << value << "'";
-      return std::pair<bool, const T&>(false, *value_);
-    }
-    value_ = value;
-    return std::pair<bool, const T&>(true, *value_);
-  }
-
-  operator bool() const { return value_.has_value(); }
-  operator const T&() const { return *value_; }
-  const T& operator*() const { return *value_; }
-  const T* operator->() const { return value_.operator->(); }
-
-  std::set<T> AsSet() const {
-    std::set<T> vset;
-    if (value_.has_value()) {
-      vset.insert(*value_);
-    }
-    return vset;
-  }
-
- private:
-  absl::optional<T> value_;
-};
-
-}  // namespace util
-}  // namespace runtime
-}  // namespace torch_xla
-
-#endif  // XLA_CLIENT_UNIQUE_H_
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index e14b11882a7..85378de0f8b 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -15,6 +15,7 @@
 #include <functional>
 #include <memory>
 #include <mutex>
+#include <optional>
 #include <set>
 #include <stdexcept>
 #include <string>
@@ -38,7 +39,6 @@
 #include "torch_xla/csrc/runtime/pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/thread_pool.h"
-#include "torch_xla/csrc/runtime/unique.h"
 #include "torch_xla/csrc/runtime/xla_util.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/torch_util.h"
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 39d866358ac..659dbfa8834 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -17,6 +17,7 @@
 #include <functional>
 #include <memory>
 #include <mutex>
+#include <optional>
 #include <set>
 #include <stdexcept>
 #include <string>
@@ -47,7 +48,6 @@
 #include "torch_xla/csrc/runtime/stablehlo_helper.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/thread_pool.h"
-#include "torch_xla/csrc/runtime/unique.h"
 #include "torch_xla/csrc/runtime/xla_util.h"
 #include "torch_xla/csrc/shape_helper.h"
 #include "torch_xla/csrc/tensor_util.h"
@@ -534,12 +534,12 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
     const std::vector<XLATensorPtr>& tensors, const SyncTensorsConfig& config) {
   tsl::profiler::TraceMe activity("CollectSyncTensors",
                                   tsl::profiler::TraceMeLevel::kInfo);
-  runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
+  std::optional<torch::lazy::BackendDevice> device;
   for (size_t i = 0; i < tensors.size(); ++i) {
-    unique_device.set(tensors[i]->GetDevice());
+    device = tensors[i]->GetDevice();
   }
   SyncTensorCollection coll;
-  if (!unique_device) {
+  if (!device) {
     return coll;
   }
 
@@ -552,7 +552,7 @@
   // graph with on/off force_ltc_data should not match, hash wise.
   coll.hash = torch::lazy::MHash(config.force_ltc_data);
   coll.config = config;
-  coll.device = *unique_device;
+  coll.device = *device;
   coll.indices.reserve(tensors.size());
   for (size_t i = 0; i < tensors.size(); ++i) {
     if (tensor_ids.insert(tensors[i]->GetUniqueId()).second &&

From e1064e01c21cbdd1b4840f457043fc63d133acf3 Mon Sep 17 00:00:00 2001
From: Will Cromar <wcromar@google.com>
Date: Thu, 9 Nov 2023 20:03:53 +0000
Subject: [PATCH 4/9] Remove hashing

---
 torch_xla/csrc/layout_manager.cpp   |  2 +-
 torch_xla/csrc/runtime/BUILD        |  8 ++-
 torch_xla/csrc/runtime/types.h      |  2 -
 torch_xla/csrc/runtime/util.cc      | 76 +---------------------
 torch_xla/csrc/runtime/util.h       | 99 -----------------------------
 torch_xla/csrc/runtime/util_test.cc | 15 -----
 torch_xla/csrc/runtime/xla_util.cc  | 15 +++--
 torch_xla/csrc/runtime/xla_util.h   |  4 +-
 torch_xla/csrc/torch_util.cpp       |  4 +-
 9 files changed, 22 insertions(+), 203 deletions(-)

diff --git a/torch_xla/csrc/layout_manager.cpp b/torch_xla/csrc/layout_manager.cpp
index 3c54b33e911..3c462f0e53b 100644
--- a/torch_xla/csrc/layout_manager.cpp
+++ b/torch_xla/csrc/layout_manager.cpp
@@ -41,7 +41,7 @@ class LayoutManager {
 
   struct DimensionsHasher {
     size_t operator()(const absl::Span<const int64_t>& dimensions) const {
-      return runtime::util::HashReduce(runtime::util::MHash(dimensions));
+      return torch::lazy::HashReduce(torch::lazy::MHash(std::vector<int64_t>({dimensions.begin(), dimensions.end()})));
     }
   };
 
diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD
index 2878f1c4b1d..dd3efffe318 100644
--- a/torch_xla/csrc/runtime/BUILD
+++ b/torch_xla/csrc/runtime/BUILD
@@ -3,6 +3,11 @@ load(
     "if_cuda_is_configured",
 )
 
+load(
+    "//bazel:rules_def.bzl",
+    "ptxla_cc_test",
+)
+
 licenses(["notice"])  # Apache 2.0
 
 package(default_visibility = ["//visibility:public"])
@@ -313,10 +318,11 @@ cc_library(
         "@xla//xla/service:platform_util",
         "@xla//xla/service/spmd:spmd_partitioner",
         "@tsl//tsl/platform:errors",
+        "@torch//:headers",
     ],
 )
 
-cc_test(
+ptxla_cc_test(
     name = "xla_util_test",
     size = "small",
     srcs = ["xla_util_test.cc"],
diff --git a/torch_xla/csrc/runtime/types.h b/torch_xla/csrc/runtime/types.h
index de71cd6c2ef..a27f1a0e1c2 100644
--- a/torch_xla/csrc/runtime/types.h
+++ b/torch_xla/csrc/runtime/types.h
@@ -11,8 +11,6 @@
 namespace torch_xla {
 namespace runtime {
 
-using hash_t = absl::uint128;
-
 struct Percentile {
   enum class UnitOfMeaure {
     kNumber,
diff --git a/torch_xla/csrc/runtime/util.cc b/torch_xla/csrc/runtime/util.cc
index caeeb149492..280350ea263 100644
--- a/torch_xla/csrc/runtime/util.cc
+++ b/torch_xla/csrc/runtime/util.cc
@@ -4,80 +4,6 @@
 
 namespace torch_xla {
 namespace runtime {
-namespace util {
-namespace {
-
-hash_t LoadHash(const uint8_t** data, const uint8_t* top) {
-  std::ptrdiff_t size = top - (*data);
-  if (size >= sizeof(hash_t)) {
-    hash_t v;
-    std::memcpy(&v, *data, sizeof(v));
-    *data += sizeof(hash_t);
-    return v;
-  }
-
-  union {
-    hash_t h;
-    uint8_t b[sizeof(hash_t)];
-  } uval;
-  uval.h = 0;
-  std::memcpy(uval.b, *data, size);
-  *data += size;
-  return uval.h;
-}
-
-}  // namespace
-
-hash_t HashBlock(const void* data, size_t n, const hash_t& seed) {
-  const hash_t m = 0xc6a4a7935bd1e995;
-  const int r = 47;
-
-  const uint8_t* u8_data = reinterpret_cast<const uint8_t*>(data);
-  const uint8_t* top = u8_data + n;
-  hash_t h = seed ^ (n * m);
-  while (u8_data < top) {
-    hash_t k = LoadHash(&u8_data, top);
-    k *= m;
-    k ^= k >> r;
-    k *= m;
-
-    h ^= k;
-    h *= m;
-  }
-  h ^= h >> r;
-  h *= m;
-  h ^= h >> r;
-  return h;
-}
-
-hash_t DataHash(const void* data, size_t size) {
-  return HashBlock(data, size, 0xc2b2ae3d27d4eb4f);
-}
-
-size_t StdDataHash(const void* data, size_t size) {
-  return HashReduce(DataHash(data, size));
-}
-
-size_t StdHashCombine(uintmax_t a, uintmax_t b) {
-  return a ^
-         (b * 0x27d4eb2f165667c5 + 0x9e3779b97f4a7c15 + (a << 6) + (a >> 2));
-}
-
-hash_t HashCombine(const hash_t& a, const hash_t& b) {
-  static const hash_t kb = absl::MakeUint128(101, 0x27d4eb2f165667c5);
-  return a ^ (b * kb + 0x9e3779b97f4a7c15 + (a << 6) + (a >> 2));
-}
-
-size_t HashReduce(const hash_t& a) {
-  return StdHashCombine(absl::Uint128Low64(a), absl::Uint128High64(a));
-}
-
-std::string HexHash(const hash_t& a) {
-  std::stringstream ss;
-  ss << std::hex << a;
-  return ss.str();
-}
-
-}  // namespace util
+namespace util {}  // namespace util
 }  // namespace runtime
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/runtime/util.h b/torch_xla/csrc/runtime/util.h
index c4d05593d84..047c1d8ea0b 100644
--- a/torch_xla/csrc/runtime/util.h
+++ b/torch_xla/csrc/runtime/util.h
@@ -24,24 +24,6 @@ namespace torch_xla {
 namespace runtime {
 namespace util {
 
-hash_t HashBlock(const void* data, size_t n, const hash_t& seed);
-
-hash_t DataHash(const void* data, size_t size);
-
-size_t StdDataHash(const void* data, size_t size);
-
-size_t StdHashCombine(uintmax_t a, uintmax_t b);
-
-hash_t HashCombine(const hash_t& a, const hash_t& b);
-
-size_t HashReduce(const hash_t& a);
-
-std::string HexHash(const hash_t& a);
-
-struct HashReducer {
-  size_t operator()(const hash_t& value) const { return HashReduce(value); }
-};
-
 template <typename F>
 xla::Status CheckedCall(const F& fn) {
   try {
@@ -139,28 +121,6 @@ class MaybePtr {
   absl::optional<T> storage_;
 };
 
-// Hasher for string-like objects which hashes only a partial window of the data
-// of size N. The P (policy) type is a functor which returns the position of the
-// window.
-template <typename T, size_t N, typename P = MidPolicy>
-struct PartialHasher {
-  size_t operator()(const T& data) const {
-    size_t pos = policy(data.size());
-    size_t end = pos + N;
-    if (end > data.size()) {
-      end = data.size();
-      if (N > data.size()) {
-        pos = 0;
-      } else {
-        pos = end - N;
-      }
-    }
-    return tsl::Hash64(data.data() + pos, end - pos, 17);
-  }
-
-  P policy;
-};
-
 template <typename C>
 std::vector<const typename C::value_type::element_type*> GetConstSharedPointers(
     const C& shared_pointers) {
@@ -271,65 +231,6 @@ T Multiply(const S& input) {
                          std::multiplies<T>());
 }
 
-static inline hash_t StringHash(const char* data) {
-  return DataHash(data, std::strlen(data));
-}
-
-template <typename T, typename std::enable_if<
-              std::is_arithmetic<T>::value>::type* = nullptr>
-hash_t Hash(const T& value) {
-  return DataHash(&value, sizeof(value));
-}
-
-static inline hash_t Hash(const std::string& value) {
-  return DataHash(value.data(), value.size());
-}
-
-// Forward declare to allow hashes of vectors of vectors to work.
-template <typename T>
-hash_t ContainerHash(const T& values);
-
-template <typename T>
-hash_t Hash(absl::Span<const T> values) {
-  return ContainerHash(values);
-}
-
-template <typename T>
-hash_t Hash(const std::vector<T>& values) {
-  return ContainerHash(values);
-}
-
-template <typename T>
-hash_t Hash(const std::set<T>& values) {
-  return ContainerHash(values);
-}
-
-template <typename T, typename S>
-hash_t Hash(const std::pair<T, S>& values) {
-  return HashCombine(Hash(values.first), Hash(values.second));
-}
-
-static inline hash_t Hash(const hash_t& value) { return value; }
-
-template <typename T>
-hash_t ContainerHash(const T& values) {
-  hash_t h = 0x85ebca77c2b2ae63;
-  for (auto& value : values) {
-    h = HashCombine(h, Hash(value));
-  }
-  return h;
-}
-
-template <typename T = void>
-hash_t MHash() {
-  return 0x165667b19e3779f9;
-}
-
-template <typename T, typename... Targs>
-hash_t MHash(T value, Targs... Fargs) {
-  return HashCombine(Hash(value), MHash(Fargs...));
-}
-
 }  // namespace util
 }  // namespace runtime
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/runtime/util_test.cc b/torch_xla/csrc/runtime/util_test.cc
index 125e8ba8bef..c7a3c5ac285 100644
--- a/torch_xla/csrc/runtime/util_test.cc
+++ b/torch_xla/csrc/runtime/util_test.cc
@@ -90,21 +90,6 @@ TEST(UtilTest, Multiply) {
   EXPECT_EQ(Multiply<int>(t), 720);
 }
 
-TEST(UtilTest, Hash) {
-  std::pair<std::string, int> temp = {"hello", 3};
-  EXPECT_EQ(Hash(std::pair<std::string, int>{"hello", 3}), Hash(temp));
-  EXPECT_EQ(HexHash(Hash(std::pair<std::string, int>{"hello", 3})),
-            HexHash(Hash(temp)));
-
-  std::vector<int> t = {1, 2, 3, 4, 5};
-  EXPECT_EQ(Hash({1, 2, 3, 4, 5}), Hash({1, 2, 3, 4, 5}));
-  EXPECT_EQ(Hash(std::set<int>{1, 2, 3}), Hash(std::set<int>{1, 2, 3}));
-  EXPECT_EQ(Hash(t), Hash(std::vector<int>{1, 2, 3, 4, 5}));
-
-  EXPECT_EQ(StdDataHash(t.data(), t.size()),
-            StdDataHash(std::vector<int>{1, 2, 3, 4, 5}.data(), t.size()));
-}
-
 TEST(UtilTest, MaybeRef) {
   using StringRef = torch_xla::runtime::util::MaybeRef<std::string>;
   std::string storage("String storage");
diff --git a/torch_xla/csrc/runtime/xla_util.cc b/torch_xla/csrc/runtime/xla_util.cc
index e591198bf7e..5eb9e009128 100644
--- a/torch_xla/csrc/runtime/xla_util.cc
+++ b/torch_xla/csrc/runtime/xla_util.cc
@@ -1,5 +1,7 @@
 #include "torch_xla/csrc/runtime/xla_util.h"
 
+#include <torch/csrc/lazy/core/hash.h>
+
 #include <fstream>
 #include <sstream>
 #include <stdexcept>
@@ -19,16 +21,17 @@ namespace runtime {
 namespace util {
 namespace {
 
-hash_t SingleShapeHash(const xla::Shape& shape, hash_t seed) {
+torch::lazy::hash_t SingleShapeHash(const xla::Shape& shape,
+                                    torch::lazy::hash_t seed) {
   if (shape.has_layout()) {
     for (auto dim : shape.layout().minor_to_major()) {
-      seed = HashCombine(seed, dim);
+      seed = torch::lazy::HashCombine(seed, dim);
     }
   }
   for (auto dim : shape.dimensions()) {
-    seed = HashCombine(seed, dim);
+    seed = torch::lazy::HashCombine(seed, dim);
  }
-  return HashCombine(seed, static_cast<int>(shape.element_type()));
+  return torch::lazy::HashCombine(seed, static_cast<int>(shape.element_type()));
 }
 
 void MaybeSaveHloGraph(const std::string& hlo_text, size_t index) {
@@ -103,8 +106,8 @@ void CheckComputationStatus(
   }
 }
 
-hash_t ShapeHash(const xla::Shape& shape) {
-  hash_t hash = 0xa5d2d6916;
+torch::lazy::hash_t ShapeHash(const xla::Shape& shape) {
+  torch::lazy::hash_t hash = 0xa5d2d6916;
   xla::ShapeUtil::ForEachSubshape(
       shape, [&](const xla::Shape& subshape, const xla::ShapeIndex&) {
         hash = SingleShapeHash(subshape, hash);
diff --git a/torch_xla/csrc/runtime/xla_util.h b/torch_xla/csrc/runtime/xla_util.h
index 32b76f69eb9..3163d5ba8c4 100644
--- a/torch_xla/csrc/runtime/xla_util.h
+++ b/torch_xla/csrc/runtime/xla_util.h
@@ -1,6 +1,8 @@
 #ifndef XLA_CLIENT_XLA_UTIL_H_
 #define XLA_CLIENT_XLA_UTIL_H_
 
+#include <torch/csrc/lazy/core/hash.h>
+
 #include <string>
 
 #include "absl/types/span.h"
@@ -35,7 +37,7 @@ void CheckComputationStatus(
     absl::Span<const xla::XlaComputation* const> computations,
     absl::Span<const xla::Shape* const> output_shapes);
 
-hash_t ShapeHash(const xla::Shape& shape);
+torch::lazy::hash_t ShapeHash(const xla::Shape& shape);
 
 }  // namespace util
 }  // namespace runtime
diff --git a/torch_xla/csrc/torch_util.cpp b/torch_xla/csrc/torch_util.cpp
index 2148478006f..1d5e3616643 100644
--- a/torch_xla/csrc/torch_util.cpp
+++ b/torch_xla/csrc/torch_util.cpp
@@ -78,9 +78,7 @@ at::Tensor MaybeWrapTensorToFunctional(const at::Tensor& tensor) {
 namespace torch {
 namespace lazy {
 torch::lazy::hash_t Hash(const xla::Shape& shape) {
-  auto shape_hash = torch_xla::runtime::util::ShapeHash(shape);
-  return c10::uint128(absl::Uint128High64(shape_hash),
-                      absl::Uint128Low64(shape_hash));
+  return torch_xla::runtime::util::ShapeHash(shape);
 }
 }  // namespace lazy
 }  // namespace torch

From 461f643d59658edcd4cf6a466434b180cb6b6979 Mon Sep 17 00:00:00 2001
From: Will Cromar <wcromar@google.com>
Date: Thu, 9 Nov 2023 20:15:33 +0000
Subject: [PATCH 5/9] more random cleanup

---
 torch_xla/csrc/runtime/util.h       | 102 ----------------------------
 torch_xla/csrc/runtime/util_test.cc |  38 -----------
 torch_xla/csrc/tensor_util.cpp      |   2 +-
 3 files changed, 1 insertion(+), 141 deletions(-)

diff --git a/torch_xla/csrc/runtime/util.h b/torch_xla/csrc/runtime/util.h
index 047c1d8ea0b..722a6591f78 100644
--- a/torch_xla/csrc/runtime/util.h
+++ b/torch_xla/csrc/runtime/util.h
@@ -24,57 +24,6 @@ namespace torch_xla {
 namespace runtime {
 namespace util {
 
-template <typename F>
-xla::Status CheckedCall(const F& fn) {
-  try {
-    fn();
-  } catch (const std::exception& ex) {
-    return tsl::errors::Internal(ex.what());
-  }
-  return xla::Status();
-}
-
-template <typename T>
-class Cleanup {
- public:
-  using StatusType = T;
-
-  explicit Cleanup(std::function<void(StatusType)> func)
-      : func_(std::move(func)) {}
-  Cleanup(Cleanup&& ref)
-      : func_(std::move(ref.func_)), status_(std::move(ref.status_)) {}
-  Cleanup(const Cleanup&) = delete;
-
-  ~Cleanup() {
-    if (func_ != nullptr) {
-      func_(std::move(status_));
-    }
-  }
-
-  Cleanup& operator=(const Cleanup&) = delete;
-
-  Cleanup& operator=(Cleanup&& ref) {
-    if (this != &ref) {
-      func_ = std::move(ref.func_);
-      status_ = std::move(ref.status_);
-    }
-    return *this;
-  }
-
-  void Release() { func_ = nullptr; }
-
-  void SetStatus(StatusType status) { status_ = std::move(status); }
-
-  const StatusType& GetStatus() const { return status_; }
-
- private:
-  std::function<void(StatusType)> func_;
-  StatusType status_;
-};
-
-using ExceptionCleanup = Cleanup<std::exception_ptr>;
-using StatusCleanup = Cleanup<xla::Status>;
-
 // Allows APIs which might return const references and values, to not be forced
 // to return values in the signature.
 template <typename T>
@@ -96,10 +45,6 @@ class MaybeRef {
   const T& ref_;
 };
 
-struct MidPolicy {
-  size_t operator()(size_t size) const { return size / 2; }
-};
-
 template <typename T>
 class MaybePtr {
  public:
@@ -121,48 +66,6 @@ class MaybePtr {
   absl::optional<T> storage_;
 };
 
-template <typename C>
-std::vector<const typename C::value_type::element_type*> GetConstSharedPointers(
-    const C& shared_pointers) {
-  std::vector<const typename C::value_type::element_type*> pointers;
-  pointers.reserve(shared_pointers.size());
-  for (auto& shared_pointer : shared_pointers) {
-    pointers.push_back(shared_pointer.get());
-  }
-  return pointers;
-}
-
-template <typename C>
-std::vector<typename C::value_type::element_type*> GetSharedPointers(
-    const C& shared_pointers) {
-  std::vector<typename C::value_type::element_type*> pointers;
-  pointers.reserve(shared_pointers.size());
-  for (auto& shared_pointer : shared_pointers) {
-    pointers.push_back(shared_pointer.get());
-  }
-  return pointers;
-}
-
-template <typename C, typename K, typename T, typename F>
-void InsertCombined(C* map, const K& key, const T& value, const F& combiner) {
-  auto it = map->find(key);
-  if (it == map->end()) {
-    map->emplace(key, value);
-  } else {
-    it->second = combiner(it->second, value);
-  }
-}
-
-template <typename T>
-std::vector<T> Iota(size_t size, T init = 0, T incr = 1) {
-  std::vector<T> result(size);
-  T value = init;
-  for (size_t i = 0; i < size; ++i, value += incr) {
-    result[i] = value;
-  }
-  return result;
-}
-
 template <typename T>
 std::vector<T> Range(T start, T end, T step = 1) {
   std::vector<T> result;
@@ -220,11 +123,6 @@ const typename T::mapped_type& MapInsert(T* cont,
   return it->second;
 }
 
-template <typename T>
-typename std::underlying_type<T>::type GetEnumValue(T value) {
-  return static_cast<typename std::underlying_type<T>::type>(value);
-}
-
 template <typename T, typename S>
 T Multiply(const S& input) {
   return std::accumulate(input.begin(), input.end(), T(1),
diff --git a/torch_xla/csrc/runtime/util_test.cc b/torch_xla/csrc/runtime/util_test.cc
index c7a3c5ac285..f65eea30ca7 100644
--- a/torch_xla/csrc/runtime/util_test.cc
+++ b/torch_xla/csrc/runtime/util_test.cc
@@ -15,36 +15,6 @@ namespace util {
 
 using ::testing::ElementsAre;
 
-TEST(UtilTest, Cleanup) {
-  bool notify = false;
-
-  // Set to true.
-  {
-    Cleanup<bool> c([&notify](bool b) { notify = b; });
-    c.SetStatus(true);
-  }
-  EXPECT_TRUE(notify);
-
-  // Set to false.
-  {
-    Cleanup<bool> c([&notify](bool b) { notify = b; });
-    c.SetStatus(false);
-  }
-  EXPECT_FALSE(notify);
-
-  // Releasing the cleanup will not change the `notify` to true.
-  {
-    Cleanup<bool> c([&notify](bool b) { notify = b; });
-    c.SetStatus(true);
-    c.Release();
-  }
-  EXPECT_FALSE(notify);
-}
-
-TEST(UtilTest, Iota) {
-  EXPECT_THAT(Iota<int>(5, 0, 2), ElementsAre(0, 2, 4, 6, 8));
-}
-
 TEST(UtilTest, Range) {
   EXPECT_THAT(Range<int>(0, 10, 2), ElementsAre(0, 2, 4, 6, 8));
   EXPECT_THAT(Range<int>(10, 0, -2), ElementsAre(10, 8, 6, 4, 2));
@@ -75,14 +45,6 @@ TEST(UtilTest, MapInsert) {
   EXPECT_EQ(MapInsert(&v, 1, [] { return 12; }), 1);
 }
 
-TEST(UtilTest, GetEnumValue) {
-  enum E { A = 0, B, C, D };
-  EXPECT_EQ(GetEnumValue(E::A), 0);
-  EXPECT_EQ(GetEnumValue(E::B), 1);
-  EXPECT_EQ(GetEnumValue(E::C), 2);
-  EXPECT_EQ(GetEnumValue(E::D), 3);
-}
-
 TEST(UtilTest, Multiply) {
   std::vector<int> t = {1, 2, 3, 4, 5};
   EXPECT_EQ(Multiply<int>(t), 120);
diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp
index a419bd98b7e..29697963c8a 100644
--- a/torch_xla/csrc/tensor_util.cpp
+++ b/torch_xla/csrc/tensor_util.cpp
@@ -36,7 +36,7 @@ namespace {
 struct DataAsync {
   std::vector<runtime::ComputationClient::TensorSource> source_tensors;
   std::vector<runtime::ComputationClient::DataPtr> async_datas;
-  std::vector<runtime::util::ExceptionCleanup> handle_unlockers;
+  std::vector<torch::lazy::ExceptionCleanup> handle_unlockers;
 };
 
 bool ShouldUseBF16() {

From 455cc9fa027a385c460e17b3cb3c41611418a73a Mon Sep 17 00:00:00 2001
From: Will Cromar <wcromar@google.com>
Date: Thu, 9 Nov 2023 20:15:51 +0000
Subject: [PATCH 6/9] formatting

---
 torch_xla/csrc/layout_manager.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torch_xla/csrc/layout_manager.cpp b/torch_xla/csrc/layout_manager.cpp
index 3c462f0e53b..b488acbaefc 100644
--- a/torch_xla/csrc/layout_manager.cpp
+++ b/torch_xla/csrc/layout_manager.cpp
@@ -41,7 +41,8 @@ class LayoutManager {
 
   struct DimensionsHasher {
     size_t operator()(const absl::Span<const int64_t>& dimensions) const {
-      return torch::lazy::HashReduce(torch::lazy::MHash(std::vector<int64_t>({dimensions.begin(), dimensions.end()})));
+      return torch::lazy::HashReduce(torch::lazy::MHash(
+          std::vector<int64_t>({dimensions.begin(), dimensions.end()})));
     }
   };
 

From ab212cfd6be1c0dac6ea2ea070188d5ca0eb2122 Mon Sep 17 00:00:00 2001
From: Will Cromar <wcromar@google.com>
Date: Fri, 10 Nov 2023 17:12:31 +0000
Subject: [PATCH 7/9] remove util.cc

---
 torch_xla/csrc/runtime/BUILD   | 1 -
 torch_xla/csrc/runtime/util.cc | 9 ---------
 2 files changed, 10 deletions(-)
 delete mode 100644 torch_xla/csrc/runtime/util.cc

diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD
index dd3efffe318..85e1e1557a6 100644
--- a/torch_xla/csrc/runtime/BUILD
+++ b/torch_xla/csrc/runtime/BUILD
@@ -275,7 +275,6 @@ cc_library(
 
 cc_library(
     name = "util",
-    srcs = ["util.cc"],
     hdrs = ["util.h"],
     deps = [
         ":types",
diff --git a/torch_xla/csrc/runtime/util.cc b/torch_xla/csrc/runtime/util.cc
deleted file mode 100644
index 280350ea263..00000000000
--- a/torch_xla/csrc/runtime/util.cc
+++ /dev/null
@@ -1,9 +0,0 @@
-#include "torch_xla/csrc/runtime/util.h"
-
-#include <sstream>
-
-namespace torch_xla {
-namespace runtime {
-namespace util {}  // namespace util
-}  // namespace runtime
-}  // namespace torch_xla

From 0146c1bf83a44e44aa520a1b6fb5a6f0823db4c9 Mon Sep 17 00:00:00 2001
From: Will Cromar <wcromar@google.com>
Date: Fri, 10 Nov 2023 17:17:06 +0000
Subject: [PATCH 8/9] Revert "remove unique"

This reverts commit ebe4567b07b5fca72c6952b8ae30d4f404d3c8e2.
---
 torch_xla/csrc/BUILD                  |  2 ++
 torch_xla/csrc/debug_util.cpp         | 32 +++++++++--------
 torch_xla/csrc/runtime/BUILD          |  9 +++++
 torch_xla/csrc/runtime/unique.h       | 50 +++++++++++++++++++++++++++
 torch_xla/csrc/tensor.cpp             |  2 +-
 torch_xla/csrc/xla_graph_executor.cpp | 10 +++---
 6 files changed, 84 insertions(+), 21 deletions(-)
 create mode 100644 torch_xla/csrc/runtime/unique.h

diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD
index 8a38c87cd00..d4eed1ec83f 100644
--- a/torch_xla/csrc/BUILD
+++ b/torch_xla/csrc/BUILD
@@ -118,8 +118,10 @@ ptxla_cc_library(
         ":layout_manager",
         ":shape_builder",
         ":shape_helper",
+        "//torch_xla/csrc/runtime:async_task",
         "//torch_xla/csrc/runtime",
         "//torch_xla/csrc/runtime:stablehlo_helper",
+        "//torch_xla/csrc/runtime:unique",
         "//torch_xla/csrc/runtime:xla_util",
         "@com_google_absl//absl/hash",
         "@com_google_absl//absl/memory",
diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp
index e601034790b..7723d6d95d9 100644
--- a/torch_xla/csrc/debug_util.cpp
+++ b/torch_xla/csrc/debug_util.cpp
@@ -6,7 +6,6 @@
 #include <fstream>
 #include <iostream>
 #include <mutex>
-#include <optional>
 #include <sstream>
 #include <unordered_set>
 
@@ -61,14 +61,14 @@ std::string DebugUtil::GetTensorsGraphHlo(
     absl::Span<const XLATensorPtr> tensors, const std::vector<size_t>* indices,
     bool dump_stablehlo) {
   std::vector<torch::lazy::Value> root_values;
-  std::optional<torch::lazy::BackendDevice> device;
+  runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
   if (indices != nullptr) {
     for (auto index : *indices) {
       const XLATensorPtr& tensor = tensors[index];
       torch::lazy::Value ir_value = tensor->CurrentIrValue();
       if (ir_value) {
         root_values.push_back(std::move(ir_value));
-        device = tensor->GetDevice();
+        unique_device.set(tensor->GetDevice());
       }
     }
   } else {
@@ -76,13 +76,13 @@
       torch::lazy::Value ir_value = tensor->CurrentIrValue();
       if (ir_value) {
         root_values.push_back(std::move(ir_value));
-        device = tensor->GetDevice();
+        unique_device.set(tensor->GetDevice());
       }
     }
   }
-  return DumpUtil::ToHlo(root_values,
-                         device.value_or(bridge::GetCurrentDevice()),
-                         EmitMode::kStableHloReadable);
+  return DumpUtil::ToHlo(
+      root_values, unique_device ? *unique_device : bridge::GetCurrentDevice(),
+      EmitMode::kStableHloReadable);
 }
 
 std::string DebugUtil::GetTensorsGraphInfo(
@@ -91,7 +91,7 @@
   std::vector<const torch::lazy::Node*> root_nodes;
   std::vector<torch::lazy::Value> root_values;
   std::vector<torch::lazy::hash_t> root_hashes;
-  std::optional<torch::lazy::BackendDevice> device;
+  runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
   if (indices != nullptr) {
     for (auto index : *indices) {
       const XLATensorPtr& tensor = tensors[index];
@@ -100,7 +100,7 @@
         root_nodes.push_back(ir_value.node.get());
         root_hashes.push_back(ir_value.hash());
         root_values.push_back(std::move(ir_value));
-        device = tensor->GetDevice();
+        unique_device.set(tensor->GetDevice());
       }
     }
   } else {
@@ -110,7 +110,7 @@
         root_nodes.push_back(ir_value.node.get());
         root_hashes.push_back(ir_value.hash());
         root_values.push_back(std::move(ir_value));
-        device = tensor->GetDevice();
+        unique_device.set(tensor->GetDevice());
       }
     }
   }
@@ -137,12 +137,14 @@
   } else if (format == GraphFormat::kDot) {
     graph_str = DumpUtil::ToDot(root_nodes);
   } else if (format == GraphFormat::kHlo) {
-    graph_str = DumpUtil::ToHlo(root_values,
-                                device.value_or(bridge::GetCurrentDevice()));
+    graph_str = DumpUtil::ToHlo(root_values, unique_device
+                                                 ? *unique_device
+                                                 : bridge::GetCurrentDevice());
   } else if (format == GraphFormat::kStableHlo) {
-    graph_str = DumpUtil::ToHlo(root_values,
-                                device.value_or(bridge::GetCurrentDevice()),
-                                EmitMode::kStableHloReadable);
+    graph_str = DumpUtil::ToHlo(
+        root_values,
+        unique_device ? *unique_device : bridge::GetCurrentDevice(),
+        EmitMode::kStableHloReadable);
   } else {
     XLA_ERROR() << "Invalid graph format: " << format;
   }
diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD
index 85e1e1557a6..b19dc0e717d 100644
--- a/torch_xla/csrc/runtime/BUILD
+++ b/torch_xla/csrc/runtime/BUILD
@@ -273,6 +273,15 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "unique",
+    hdrs = ["unique.h"],
+    deps = [
+        ":debug_macros",
+        "@com_google_absl//absl/types:optional",
+    ],
+)
+
 cc_library(
     name = "util",
     hdrs = ["util.h"],
diff --git a/torch_xla/csrc/runtime/unique.h b/torch_xla/csrc/runtime/unique.h
new file mode 100644
index 00000000000..f50e24320d9
--- /dev/null
+++ b/torch_xla/csrc/runtime/unique.h
@@ -0,0 +1,50 @@
+#ifndef XLA_CLIENT_UNIQUE_H_
+#define XLA_CLIENT_UNIQUE_H_
+
+#include <functional>
+#include <set>
+
+#include "absl/types/optional.h"
+#include "torch_xla/csrc/runtime/debug_macros.h"
+
+namespace torch_xla {
+namespace runtime {
+namespace util {
+
+// Helper class to allow tracking zero or more things, which should be forcibly
+// be one only thing.
+template <typename T, typename C = std::equal_to<T>>
+class Unique {
+ public:
+  std::pair<bool, const T&> set(const T& value) {
+    if (value_) {
+      XLA_CHECK(C()(*value_, value))
+          << "'" << *value_ << "' vs '" << value << "'";
+      return std::pair<bool, const T&>(false, *value_);
+    }
+    value_ = value;
+    return std::pair<bool, const T&>(true, *value_);
+  }
+
+  operator bool() const { return value_.has_value(); }
+  operator const T&() const { return *value_; }
+  const T& operator*() const { return *value_; }
+  const T* operator->() const { return value_.operator->(); }
+
+  std::set<T> AsSet() const {
+    std::set<T> vset;
+    if (value_.has_value()) {
+      vset.insert(*value_);
+    }
+    return vset;
+  }
+
+ private:
+  absl::optional<T> value_;
+};
+
+}  // namespace util
+}  // namespace runtime
+}  // namespace torch_xla
+
+#endif  // XLA_CLIENT_UNIQUE_H_
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 85378de0f8b..e14b11882a7 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -15,7 +15,6 @@
 #include <functional>
 #include <memory>
 #include <mutex>
-#include <optional>
 #include <set>
 #include <stdexcept>
 #include <string>
@@ -39,6 +38,7 @@
 #include "torch_xla/csrc/runtime/pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/thread_pool.h"
+#include "torch_xla/csrc/runtime/unique.h"
 #include "torch_xla/csrc/runtime/xla_util.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/torch_util.h"
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 659dbfa8834..39d866358ac 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -17,7 +17,6 @@
 #include <functional>
 #include <memory>
 #include <mutex>
-#include <optional>
 #include <set>
 #include <stdexcept>
 #include <string>
@@ -48,6 +47,7 @@
 #include "torch_xla/csrc/runtime/stablehlo_helper.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/thread_pool.h"
+#include "torch_xla/csrc/runtime/unique.h"
 #include "torch_xla/csrc/runtime/xla_util.h"
 #include "torch_xla/csrc/shape_helper.h"
 #include "torch_xla/csrc/tensor_util.h"
@@ -534,12 +534,12 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
     const std::vector<XLATensorPtr>& tensors, const SyncTensorsConfig& config) {
   tsl::profiler::TraceMe activity("CollectSyncTensors",
                                   tsl::profiler::TraceMeLevel::kInfo);
-  std::optional<torch::lazy::BackendDevice> device;
+  runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
   for (size_t i = 0; i < tensors.size(); ++i) {
-    device = tensors[i]->GetDevice();
+    unique_device.set(tensors[i]->GetDevice());
   }
   SyncTensorCollection coll;
-  if (!device) {
+  if (!unique_device) {
     return coll;
   }
 
@@ -552,7 +552,7 @@
   // graph with on/off force_ltc_data should not match, hash wise.
   coll.hash = torch::lazy::MHash(config.force_ltc_data);
   coll.config = config;
-  coll.device = *device;
+  coll.device = *unique_device;
   coll.indices.reserve(tensors.size());
   for (size_t i = 0; i < tensors.size(); ++i) {
     if (tensor_ids.insert(tensors[i]->GetUniqueId()).second &&

From 9b99aeda796a14ddafb3fff1d1e0af3cc1bad1f2 Mon Sep 17 00:00:00 2001
From: Will Cromar <wcromar@google.com>
Date: Fri, 10 Nov 2023 17:23:23 +0000
Subject: [PATCH 9/9] Use upstream Unique

---
 torch_xla/csrc/BUILD                  |  2 --
 torch_xla/csrc/debug_util.cpp         |  6 ++--
 torch_xla/csrc/runtime/BUILD          |  9 -----
 torch_xla/csrc/runtime/unique.h       | 50 ---------------------------
 torch_xla/csrc/tensor.cpp             |  1 -
 torch_xla/csrc/xla_graph_executor.cpp |  4 +--
 6 files changed, 5 insertions(+), 67 deletions(-)
 delete mode 100644 torch_xla/csrc/runtime/unique.h

diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD
index d4eed1ec83f..8a38c87cd00 100644
--- a/torch_xla/csrc/BUILD
+++ b/torch_xla/csrc/BUILD
@@ -118,10 +118,8 @@ ptxla_cc_library(
         ":layout_manager",
         ":shape_builder",
         ":shape_helper",
-        "//torch_xla/csrc/runtime:async_task",
         "//torch_xla/csrc/runtime",
         "//torch_xla/csrc/runtime:stablehlo_helper",
-        "//torch_xla/csrc/runtime:unique",
         "//torch_xla/csrc/runtime:xla_util",
         "@com_google_absl//absl/hash",
         "@com_google_absl//absl/memory",
diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp
index 7723d6d95d9..9959d46f8a2 100644
--- a/torch_xla/csrc/debug_util.cpp
+++ b/torch_xla/csrc/debug_util.cpp
@@ -1,6 +1,7 @@
 #include "torch_xla/csrc/debug_util.h"
 
 #include <torch/csrc/lazy/core/hash.h>
+#include <torch/csrc/lazy/core/unique.h>
 #include <torch/csrc/lazy/python/python_util.h>
 
 #include <fstream>
@@ -17,7 +18,6 @@
 #include "torch_xla/csrc/ir_dump_util.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
-#include "torch_xla/csrc/runtime/unique.h"
 #include "torch_xla/csrc/xla_graph_executor.h"
 
 namespace torch_xla {
@@ -61,7 +61,7 @@ std::string DebugUtil::GetTensorsGraphHlo(
     absl::Span<const XLATensorPtr> tensors, const std::vector<size_t>* indices,
     bool dump_stablehlo) {
   std::vector<torch::lazy::Value> root_values;
-  runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
+  torch::lazy::Unique<torch::lazy::BackendDevice> unique_device;
   if (indices != nullptr) {
     for (auto index : *indices) {
       const XLATensorPtr& tensor = tensors[index];
@@ -91,7 +91,7 @@ std::string DebugUtil::GetTensorsGraphInfo(
   std::vector<const torch::lazy::Node*> root_nodes;
   std::vector<torch::lazy::Value> root_values;
   std::vector<torch::lazy::hash_t> root_hashes;
-  runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
+  torch::lazy::Unique<torch::lazy::BackendDevice> unique_device;
   if (indices != nullptr) {
     for (auto index : *indices) {
       const XLATensorPtr& tensor = tensors[index];
diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD
index b19dc0e717d..85e1e1557a6 100644
--- a/torch_xla/csrc/runtime/BUILD
+++ b/torch_xla/csrc/runtime/BUILD
@@ -273,15 +273,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "unique",
-    hdrs = ["unique.h"],
-    deps = [
-        ":debug_macros",
-        "@com_google_absl//absl/types:optional",
-    ],
-)
-
 cc_library(
     name = "util",
     hdrs = ["util.h"],
diff --git a/torch_xla/csrc/runtime/unique.h b/torch_xla/csrc/runtime/unique.h
deleted file mode 100644
index f50e24320d9..00000000000
--- a/torch_xla/csrc/runtime/unique.h
+++ /dev/null
@@ -1,50 +0,0 @@
-#ifndef XLA_CLIENT_UNIQUE_H_
-#define XLA_CLIENT_UNIQUE_H_
-
-#include <functional>
-#include <set>
-
-#include "absl/types/optional.h"
-#include "torch_xla/csrc/runtime/debug_macros.h"
-
-namespace torch_xla {
-namespace runtime {
-namespace util {
-
-// Helper class to allow tracking zero or more things, which should be forcibly
-// be one only thing.
-template <typename T, typename C = std::equal_to<T>>
-class Unique {
- public:
-  std::pair<bool, const T&> set(const T& value) {
-    if (value_) {
-      XLA_CHECK(C()(*value_, value))
-          << "'" << *value_ << "' vs '" << value << "'";
-      return std::pair<bool, const T&>(false, *value_);
-    }
-    value_ = value;
-    return std::pair<bool, const T&>(true, *value_);
-  }
-
-  operator bool() const { return value_.has_value(); }
-  operator const T&() const { return *value_; }
-  const T& operator*() const { return *value_; }
-  const T* operator->() const { return value_.operator->(); }
-
-  std::set<T> AsSet() const {
-    std::set<T> vset;
-    if (value_.has_value()) {
-      vset.insert(*value_);
-    }
-    return vset;
-  }
-
- private:
-  absl::optional<T> value_;
-};
-
-}  // namespace util
-}  // namespace runtime
-}  // namespace torch_xla
-
-#endif  // XLA_CLIENT_UNIQUE_H_
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index e14b11882a7..a5dc91d27ce 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -38,7 +38,6 @@
 #include "torch_xla/csrc/runtime/pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/thread_pool.h"
-#include "torch_xla/csrc/runtime/unique.h"
 #include "torch_xla/csrc/runtime/xla_util.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/torch_util.h"
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 39d866358ac..9337f779b4f 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -8,6 +8,7 @@
 #include <torch/csrc/lazy/core/lazy_graph_executor.h>
 #include <torch/csrc/lazy/core/metrics.h>
 #include <torch/csrc/lazy/core/tensor_util.h>
+#include <torch/csrc/lazy/core/unique.h>
 #include <torch/csrc/lazy/core/util.h>
 
 #include <algorithm>
@@ -47,7 +48,6 @@
 #include "torch_xla/csrc/runtime/stablehlo_helper.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/thread_pool.h"
-#include "torch_xla/csrc/runtime/unique.h"
 #include "torch_xla/csrc/runtime/xla_util.h"
 #include "torch_xla/csrc/shape_helper.h"
 #include "torch_xla/csrc/tensor_util.h"
@@ -534,7 +534,7 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
     const std::vector<XLATensorPtr>& tensors, const SyncTensorsConfig& config) {
   tsl::profiler::TraceMe activity("CollectSyncTensors",
                                   tsl::profiler::TraceMeLevel::kInfo);
-  runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
+  torch::lazy::Unique<torch::lazy::BackendDevice> unique_device;
   for (size_t i = 0; i < tensors.size(); ++i) {
     unique_device.set(tensors[i]->GetDevice());
   }
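
Postscript (reviewer note, not part of the applied series): the net effect of patches 3, 8, and 9 is that the hand-rolled runtime::util::Unique is replaced by the upstream torch::lazy::Unique rather than by std::optional. The behavioral difference is worth spelling out: Unique::set() checks that every value it sees agrees with the first one, while the std::optional variant from patch 3 silently keeps the last device it is handed. Below is a minimal, self-contained sketch of that difference; the Unique class in it is a simplified stand-in written for this note (it only mirrors the set()/bool/deref interface visible in the diffs above, using assert instead of XLA_CHECK), and the device strings are made up.

    #include <cassert>
    #include <optional>
    #include <string>

    // Simplified stand-in for the Unique helper: remembers the first value it
    // is given and asserts that every later set() call passes an equal value.
    template <typename T>
    class Unique {
     public:
      void set(const T& value) {
        if (value_) {
          assert(*value_ == value && "Unique::set saw conflicting values");
        } else {
          value_ = value;
        }
      }
      explicit operator bool() const { return value_.has_value(); }
      const T& operator*() const { return *value_; }

     private:
      std::optional<T> value_;
    };

    int main() {
      // std::optional (the patch 3 approach): the last assignment wins, so a
      // mix of devices would go unnoticed.
      std::optional<std::string> device;
      device = "TPU:0";
      device = "TPU:1";  // no complaint; the earlier value is silently lost

      // Unique (the patch 8/9 approach): mixing devices trips the check.
      Unique<std::string> unique_device;
      unique_device.set("TPU:0");
      unique_device.set("TPU:0");   // fine: same value as before
      // unique_device.set("TPU:1");  // would assert: values must agree
      assert(unique_device && *unique_device == "TPU:0");
      return 0;
    }

Failing loudly on mixed devices, instead of proceeding with whichever device happened to be assigned last, is presumably why the std::optional rewrite was reverted in patch 8 and the upstream Unique adopted in patch 9.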