diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index d4eed1ec83f..8a38c87cd00 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -118,10 +118,8 @@ ptxla_cc_library( ":layout_manager", ":shape_builder", ":shape_helper", - "//torch_xla/csrc/runtime:async_task", "//torch_xla/csrc/runtime", "//torch_xla/csrc/runtime:stablehlo_helper", - "//torch_xla/csrc/runtime:unique", "//torch_xla/csrc/runtime:xla_util", "@com_google_absl//absl/hash", "@com_google_absl//absl/memory", diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp index 7723d6d95d9..9959d46f8a2 100644 --- a/torch_xla/csrc/debug_util.cpp +++ b/torch_xla/csrc/debug_util.cpp @@ -1,6 +1,7 @@ #include "torch_xla/csrc/debug_util.h" #include +#include #include #include @@ -17,7 +18,6 @@ #include "torch_xla/csrc/ir_dump_util.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/sys_util.h" -#include "torch_xla/csrc/runtime/unique.h" #include "torch_xla/csrc/xla_graph_executor.h" namespace torch_xla { @@ -61,7 +61,7 @@ std::string DebugUtil::GetTensorsGraphHlo( absl::Span tensors, const std::vector* indices, bool dump_stablehlo) { std::vector root_values; - runtime::util::Unique unique_device; + torch::lazy::Unique unique_device; if (indices != nullptr) { for (auto index : *indices) { const XLATensorPtr& tensor = tensors[index]; @@ -91,7 +91,7 @@ std::string DebugUtil::GetTensorsGraphInfo( std::vector root_nodes; std::vector root_values; std::vector root_hashes; - runtime::util::Unique unique_device; + torch::lazy::Unique unique_device; if (indices != nullptr) { for (auto index : *indices) { const XLATensorPtr& tensor = tensors[index]; diff --git a/torch_xla/csrc/layout_manager.cpp b/torch_xla/csrc/layout_manager.cpp index 3c54b33e911..b488acbaefc 100644 --- a/torch_xla/csrc/layout_manager.cpp +++ b/torch_xla/csrc/layout_manager.cpp @@ -41,7 +41,8 @@ class LayoutManager { struct DimensionsHasher { size_t operator()(const absl::Span& dimensions) const { - return runtime::util::HashReduce(runtime::util::MHash(dimensions)); + return torch::lazy::HashReduce(torch::lazy::MHash( + std::vector({dimensions.begin(), dimensions.end()}))); } }; diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index da12492d2cd..85e1e1557a6 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -3,30 +3,15 @@ load( "if_cuda_is_configured", ) +load( + "//bazel:rules_def.bzl", + "ptxla_cc_test", +) + licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) -cc_library( - name = "async_task", - hdrs = ["async_task.h"], - deps = [ - ":debug_macros", - ":thread_pool", - "@com_google_absl//absl/types:optional", - ], -) - -cc_test( - name = "async_task_test", - size = "small", - srcs = ["async_task_test.cc"], - deps = [ - ":async_task", - "@com_google_googletest//:gtest_main", - ], -) - cc_library( name = "runtime", srcs = [ @@ -201,20 +186,6 @@ cc_library( ], ) -cc_library( - name = "nccl_distributed", - srcs = ["nccl_distributed.cc"], - hdrs = ["nccl_distributed.h"], - deps = [ - ":debug_macros", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@xla//xla:types", - ] + if_cuda_is_configured([ - "@local_config_nccl//:nccl", - ]), -) - cc_library( name = "profiler", srcs = ["profiler.cc"], @@ -302,18 +273,8 @@ cc_library( ], ) -cc_library( - name = "unique", - hdrs = ["unique.h"], - deps = [ - ":debug_macros", - "@com_google_absl//absl/types:optional", - ], -) - cc_library( name = "util", - srcs = ["util.cc"], hdrs = ["util.h"], deps = [ ":types", @@ -356,10 +317,11 @@ cc_library( "@xla//xla/service:platform_util", "@xla//xla/service/spmd:spmd_partitioner", "@tsl//tsl/platform:errors", + "@torch//:headers", ], ) -cc_test( +ptxla_cc_test( name = "xla_util_test", size = "small", srcs = ["xla_util_test.cc"], diff --git a/torch_xla/csrc/runtime/async_task.h b/torch_xla/csrc/runtime/async_task.h deleted file mode 100644 index 73d923e0eb2..00000000000 --- a/torch_xla/csrc/runtime/async_task.h +++ /dev/null @@ -1,93 +0,0 @@ -#ifndef XLA_CLIENT_ASYNC_TASK_H_ -#define XLA_CLIENT_ASYNC_TASK_H_ - -#include -#include -#include -#include -#include - -#include "absl/types/optional.h" -#include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/runtime/thread_pool.h" - -namespace torch_xla { -namespace runtime { -namespace util { - -template -class AsyncTask { - struct Data { - Data(std::function taskfn) : taskfn(std::move(taskfn)) {} - - std::function taskfn; - std::mutex mutex; - std::condition_variable cv; - bool scheduled = false; - bool completed = false; - absl::optional result; - std::exception_ptr exptr; - }; - - public: - explicit AsyncTask(std::function taskfn) - : data_(std::make_shared(std::move(taskfn))) {} - - AsyncTask& Wait() { - std::unique_lock lock(data_->mutex); - XLA_CHECK(data_->scheduled); - data_->cv.wait(lock, [this] { return data_->completed; }); - if (data_->exptr != nullptr) { - std::rethrow_exception(data_->exptr); - } - return *this; - } - - AsyncTask& Schedule() { - auto completer = [data = data_]() { - absl::optional result; - std::exception_ptr exptr; - try { - result = data->taskfn(); - } catch (...) { - exptr = std::current_exception(); - } - - std::lock_guard lock(data->mutex); - if (result) { - data->result = std::move(*result); - } else { - data->exptr = std::move(exptr); - } - data->completed = true; - data->cv.notify_all(); - }; - - { - std::lock_guard lock(data_->mutex); - XLA_CHECK(!data_->scheduled); - data_->scheduled = true; - } - torch_xla::runtime::env::ScheduleIoClosure(std::move(completer)); - return *this; - } - - const T& GetValue() const { - std::lock_guard lock(data_->mutex); - return *data_->result; - } - - T ConsumeValue() { - std::lock_guard lock(data_->mutex); - return std::move(*data_->result); - } - - private: - std::shared_ptr data_; -}; - -} // namespace util -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_ASYNC_TASK_H_ diff --git a/torch_xla/csrc/runtime/async_task_test.cc b/torch_xla/csrc/runtime/async_task_test.cc deleted file mode 100644 index 9b7a98c5a1f..00000000000 --- a/torch_xla/csrc/runtime/async_task_test.cc +++ /dev/null @@ -1,65 +0,0 @@ -#include "torch_xla/csrc/runtime/async_task.h" - -#include - -#include - -namespace torch_xla { -namespace runtime { - -TEST(AsyncTaskTest, BaseTest) { - auto taskfn = []() -> int { return 17; }; - - torch_xla::runtime::util::AsyncTask async(std::move(taskfn)); - async.Schedule(); - async.Wait(); - EXPECT_EQ(async.GetValue(), 17); -} - -TEST(AsyncTaskTest, ExceptionTest) { - auto taskfn = []() -> int { throw std::runtime_error("Task Exception"); }; - - torch_xla::runtime::util::AsyncTask async(std::move(taskfn)); - async.Schedule(); - bool got_exception = false; - try { - async.Wait(); - } catch (const std::exception&) { - got_exception = true; - } - EXPECT_TRUE(got_exception); -} - -TEST(AsyncTaskTest, NoResultCopyTest) { - struct Result { - Result(int* counter) : counter(counter) {} - Result(const Result& ref) : counter(ref.counter) { ++(*counter); } - - Result& operator=(const Result& ref) { - if (this != &ref) { - counter = ref.counter; - ++(*counter); - } - return *this; - } - - Result(Result&&) = default; - Result& operator=(Result&&) = default; - - int* counter = nullptr; - }; - - int copy_counter = 0; - auto taskfn = [&]() -> Result { return Result(©_counter); }; - - torch_xla::runtime::util::AsyncTask async(std::move(taskfn)); - async.Schedule(); - async.Wait(); - - Result result = async.ConsumeValue(); - EXPECT_EQ(copy_counter, 0); - EXPECT_EQ(result.counter, ©_counter); -} - -} // namespace runtime -} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/nccl_distributed.cc b/torch_xla/csrc/runtime/nccl_distributed.cc deleted file mode 100644 index 51088913b88..00000000000 --- a/torch_xla/csrc/runtime/nccl_distributed.cc +++ /dev/null @@ -1,71 +0,0 @@ -#include "torch_xla/csrc/runtime/nccl_distributed.h" - -#include -#include - -#include "absl/strings/str_join.h" -#include "torch_xla/csrc/runtime/debug_macros.h" -#if XLA_CUDA -#include "third_party/nccl/nccl.h" -#endif - -namespace torch_xla { -namespace runtime { -namespace nccl_detail { - -#if XLA_CUDA - -namespace { - -class NcclUidManager { - public: - static NcclUidManager* Get(); - - std::string GetNcclUniqueUid(absl::Span replicas); - - private: - std::mutex mutex_; - std::map replicas_uid_map_; -}; - -NcclUidManager* NcclUidManager::Get() { - static NcclUidManager* nccl_mgr = new NcclUidManager(); - return nccl_mgr; -} - -std::string NcclUidManager::GetNcclUniqueUid( - absl::Span replicas) { - std::string replicas_str = absl::StrJoin(replicas, ","); - std::lock_guard lock(mutex_); - auto it = replicas_uid_map_.find(replicas_str); - if (it == replicas_uid_map_.end()) { - ncclUniqueId id; - ncclResult_t r = ncclGetUniqueId(&id); - XLA_CHECK_EQ(r, ncclSuccess) - << "NCCL UID generation failed: replicas=(" << replicas_str - << "), error: " << ncclGetErrorString(r); - it = replicas_uid_map_ - .emplace(std::move(replicas_str), - std::string(id.internal, NCCL_UNIQUE_ID_BYTES)) - .first; - } - return it->second; -} - -} // namespace - -std::string GetNcclUniqueUid(absl::Span replicas) { - return NcclUidManager::Get()->GetNcclUniqueUid(replicas); -} - -#else // XLA_CUDA - -std::string GetNcclUniqueUid(absl::Span replicas) { - XLA_ERROR() << "Calling GetNcclUniqueUid() without NCCL configuration"; -} - -#endif // XLA_CUDA - -} // namespace nccl_detail -} // namespace runtime -} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/nccl_distributed.h b/torch_xla/csrc/runtime/nccl_distributed.h deleted file mode 100644 index de5e0b0887d..00000000000 --- a/torch_xla/csrc/runtime/nccl_distributed.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef XLA_CLIENT_NCCL_DISTRIBUTED_H_ -#define XLA_CLIENT_NCCL_DISTRIBUTED_H_ - -#include - -#include "absl/types/span.h" -#include "xla/types.h" - -namespace torch_xla { -namespace runtime { -namespace nccl_detail { - -std::string GetNcclUniqueUid(absl::Span replicas); - -} // namespace nccl_detail -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_NCCL_DISTRIBUTED_H_ diff --git a/torch_xla/csrc/runtime/types.h b/torch_xla/csrc/runtime/types.h index de71cd6c2ef..a27f1a0e1c2 100644 --- a/torch_xla/csrc/runtime/types.h +++ b/torch_xla/csrc/runtime/types.h @@ -11,8 +11,6 @@ namespace torch_xla { namespace runtime { -using hash_t = absl::uint128; - struct Percentile { enum class UnitOfMeaure { kNumber, diff --git a/torch_xla/csrc/runtime/unique.h b/torch_xla/csrc/runtime/unique.h deleted file mode 100644 index f50e24320d9..00000000000 --- a/torch_xla/csrc/runtime/unique.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef XLA_CLIENT_UNIQUE_H_ -#define XLA_CLIENT_UNIQUE_H_ - -#include -#include - -#include "absl/types/optional.h" -#include "torch_xla/csrc/runtime/debug_macros.h" - -namespace torch_xla { -namespace runtime { -namespace util { - -// Helper class to allow tracking zero or more things, which should be forcibly -// be one only thing. -template > -class Unique { - public: - std::pair set(const T& value) { - if (value_) { - XLA_CHECK(C()(*value_, value)) - << "'" << *value_ << "' vs '" << value << "'"; - return std::pair(false, *value_); - } - value_ = value; - return std::pair(true, *value_); - } - - operator bool() const { return value_.has_value(); } - operator const T&() const { return *value_; } - const T& operator*() const { return *value_; } - const T* operator->() const { return value_.operator->(); } - - std::set AsSet() const { - std::set vset; - if (value_.has_value()) { - vset.insert(*value_); - } - return vset; - } - - private: - absl::optional value_; -}; - -} // namespace util -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_UNIQUE_H_ diff --git a/torch_xla/csrc/runtime/util.cc b/torch_xla/csrc/runtime/util.cc deleted file mode 100644 index caeeb149492..00000000000 --- a/torch_xla/csrc/runtime/util.cc +++ /dev/null @@ -1,83 +0,0 @@ -#include "torch_xla/csrc/runtime/util.h" - -#include - -namespace torch_xla { -namespace runtime { -namespace util { -namespace { - -hash_t LoadHash(const uint8_t** data, const uint8_t* top) { - std::ptrdiff_t size = top - (*data); - if (size >= sizeof(hash_t)) { - hash_t v; - std::memcpy(&v, *data, sizeof(v)); - *data += sizeof(hash_t); - return v; - } - - union { - hash_t h; - uint8_t b[sizeof(hash_t)]; - } uval; - uval.h = 0; - std::memcpy(uval.b, *data, size); - *data += size; - return uval.h; -} - -} // namespace - -hash_t HashBlock(const void* data, size_t n, const hash_t& seed) { - const hash_t m = 0xc6a4a7935bd1e995; - const int r = 47; - - const uint8_t* u8_data = reinterpret_cast(data); - const uint8_t* top = u8_data + n; - hash_t h = seed ^ (n * m); - while (u8_data < top) { - hash_t k = LoadHash(&u8_data, top); - k *= m; - k ^= k >> r; - k *= m; - - h ^= k; - h *= m; - } - h ^= h >> r; - h *= m; - h ^= h >> r; - return h; -} - -hash_t DataHash(const void* data, size_t size) { - return HashBlock(data, size, 0xc2b2ae3d27d4eb4f); -} - -size_t StdDataHash(const void* data, size_t size) { - return HashReduce(DataHash(data, size)); -} - -size_t StdHashCombine(uintmax_t a, uintmax_t b) { - return a ^ - (b * 0x27d4eb2f165667c5 + 0x9e3779b97f4a7c15 + (a << 6) + (a >> 2)); -} - -hash_t HashCombine(const hash_t& a, const hash_t& b) { - static const hash_t kb = absl::MakeUint128(101, 0x27d4eb2f165667c5); - return a ^ (b * kb + 0x9e3779b97f4a7c15 + (a << 6) + (a >> 2)); -} - -size_t HashReduce(const hash_t& a) { - return StdHashCombine(absl::Uint128Low64(a), absl::Uint128High64(a)); -} - -std::string HexHash(const hash_t& a) { - std::stringstream ss; - ss << std::hex << a; - return ss.str(); -} - -} // namespace util -} // namespace runtime -} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/util.h b/torch_xla/csrc/runtime/util.h index c4d05593d84..722a6591f78 100644 --- a/torch_xla/csrc/runtime/util.h +++ b/torch_xla/csrc/runtime/util.h @@ -24,75 +24,6 @@ namespace torch_xla { namespace runtime { namespace util { -hash_t HashBlock(const void* data, size_t n, const hash_t& seed); - -hash_t DataHash(const void* data, size_t size); - -size_t StdDataHash(const void* data, size_t size); - -size_t StdHashCombine(uintmax_t a, uintmax_t b); - -hash_t HashCombine(const hash_t& a, const hash_t& b); - -size_t HashReduce(const hash_t& a); - -std::string HexHash(const hash_t& a); - -struct HashReducer { - size_t operator()(const hash_t& value) const { return HashReduce(value); } -}; - -template -xla::Status CheckedCall(const F& fn) { - try { - fn(); - } catch (const std::exception& ex) { - return tsl::errors::Internal(ex.what()); - } - return xla::Status(); -} - -template -class Cleanup { - public: - using StatusType = T; - - explicit Cleanup(std::function func) - : func_(std::move(func)) {} - Cleanup(Cleanup&& ref) - : func_(std::move(ref.func_)), status_(std::move(ref.status_)) {} - Cleanup(const Cleanup&) = delete; - - ~Cleanup() { - if (func_ != nullptr) { - func_(std::move(status_)); - } - } - - Cleanup& operator=(const Cleanup&) = delete; - - Cleanup& operator=(Cleanup&& ref) { - if (this != &ref) { - func_ = std::move(ref.func_); - status_ = std::move(ref.status_); - } - return *this; - } - - void Release() { func_ = nullptr; } - - void SetStatus(StatusType status) { status_ = std::move(status); } - - const StatusType& GetStatus() const { return status_; } - - private: - std::function func_; - StatusType status_; -}; - -using ExceptionCleanup = Cleanup; -using StatusCleanup = Cleanup; - // Allows APIs which might return const references and values, to not be forced // to return values in the signature. template @@ -114,10 +45,6 @@ class MaybeRef { const T& ref_; }; -struct MidPolicy { - size_t operator()(size_t size) const { return size / 2; } -}; - template class MaybePtr { public: @@ -139,70 +66,6 @@ class MaybePtr { absl::optional storage_; }; -// Hasher for string-like objects which hashes only a partial window of the data -// of size N. The P (policy) type is a functor which returns the position of the -// window. -template -struct PartialHasher { - size_t operator()(const T& data) const { - size_t pos = policy(data.size()); - size_t end = pos + N; - if (end > data.size()) { - end = data.size(); - if (N > data.size()) { - pos = 0; - } else { - pos = end - N; - } - } - return tsl::Hash64(data.data() + pos, end - pos, 17); - } - - P policy; -}; - -template -std::vector GetConstSharedPointers( - const C& shared_pointers) { - std::vector pointers; - pointers.reserve(shared_pointers.size()); - for (auto& shared_pointer : shared_pointers) { - pointers.push_back(shared_pointer.get()); - } - return pointers; -} - -template -std::vector GetSharedPointers( - const C& shared_pointers) { - std::vector pointers; - pointers.reserve(shared_pointers.size()); - for (auto& shared_pointer : shared_pointers) { - pointers.push_back(shared_pointer.get()); - } - return pointers; -} - -template -void InsertCombined(C* map, const K& key, const T& value, const F& combiner) { - auto it = map->find(key); - if (it == map->end()) { - map->emplace(key, value); - } else { - it->second = combiner(it->second, value); - } -} - -template -std::vector Iota(size_t size, T init = 0, T incr = 1) { - std::vector result(size); - T value = init; - for (size_t i = 0; i < size; ++i, value += incr) { - result[i] = value; - } - return result; -} - template std::vector Range(T start, T end, T step = 1) { std::vector result; @@ -260,76 +123,12 @@ const typename T::mapped_type& MapInsert(T* cont, return it->second; } -template -typename std::underlying_type::type GetEnumValue(T value) { - return static_cast::type>(value); -} - template T Multiply(const S& input) { return std::accumulate(input.begin(), input.end(), T(1), std::multiplies()); } -static inline hash_t StringHash(const char* data) { - return DataHash(data, std::strlen(data)); -} - -template ::value>::type* = nullptr> -hash_t Hash(const T& value) { - return DataHash(&value, sizeof(value)); -} - -static inline hash_t Hash(const std::string& value) { - return DataHash(value.data(), value.size()); -} - -// Forward declare to allow hashes of vectors of vectors to work. -template -hash_t ContainerHash(const T& values); - -template -hash_t Hash(absl::Span values) { - return ContainerHash(values); -} - -template -hash_t Hash(const std::vector& values) { - return ContainerHash(values); -} - -template -hash_t Hash(const std::set& values) { - return ContainerHash(values); -} - -template -hash_t Hash(const std::pair& values) { - return HashCombine(Hash(values.first), Hash(values.second)); -} - -static inline hash_t Hash(const hash_t& value) { return value; } - -template -hash_t ContainerHash(const T& values) { - hash_t h = 0x85ebca77c2b2ae63; - for (auto& value : values) { - h = HashCombine(h, Hash(value)); - } - return h; -} - -template -hash_t MHash() { - return 0x165667b19e3779f9; -} - -template -hash_t MHash(T value, Targs... Fargs) { - return HashCombine(Hash(value), MHash(Fargs...)); -} - } // namespace util } // namespace runtime } // namespace torch_xla diff --git a/torch_xla/csrc/runtime/util_test.cc b/torch_xla/csrc/runtime/util_test.cc index 125e8ba8bef..f65eea30ca7 100644 --- a/torch_xla/csrc/runtime/util_test.cc +++ b/torch_xla/csrc/runtime/util_test.cc @@ -15,36 +15,6 @@ namespace util { using ::testing::ElementsAre; -TEST(UtilTest, Cleanup) { - bool notify = false; - - // Set to true. - { - Cleanup c([¬ify](bool b) { notify = b; }); - c.SetStatus(true); - } - EXPECT_TRUE(notify); - - // Set to false. - { - Cleanup c([¬ify](bool b) { notify = b; }); - c.SetStatus(false); - } - EXPECT_FALSE(notify); - - // Releasing the cleanup will not change the `notify` to true. - { - Cleanup c([¬ify](bool b) { notify = b; }); - c.SetStatus(true); - c.Release(); - } - EXPECT_FALSE(notify); -} - -TEST(UtilTest, Iota) { - EXPECT_THAT(Iota(5, 0, 2), ElementsAre(0, 2, 4, 6, 8)); -} - TEST(UtilTest, Range) { EXPECT_THAT(Range(0, 10, 2), ElementsAre(0, 2, 4, 6, 8)); EXPECT_THAT(Range(10, 0, -2), ElementsAre(10, 8, 6, 4, 2)); @@ -75,14 +45,6 @@ TEST(UtilTest, MapInsert) { EXPECT_EQ(MapInsert(&v, 1, [] { return 12; }), 1); } -TEST(UtilTest, GetEnumValue) { - enum E { A = 0, B, C, D }; - EXPECT_EQ(GetEnumValue(E::A), 0); - EXPECT_EQ(GetEnumValue(E::B), 1); - EXPECT_EQ(GetEnumValue(E::C), 2); - EXPECT_EQ(GetEnumValue(E::D), 3); -} - TEST(UtilTest, Multiply) { std::vector t = {1, 2, 3, 4, 5}; EXPECT_EQ(Multiply(t), 120); @@ -90,21 +52,6 @@ TEST(UtilTest, Multiply) { EXPECT_EQ(Multiply(t), 720); } -TEST(UtilTest, Hash) { - std::pair temp = {"hello", 3}; - EXPECT_EQ(Hash(std::pair{"hello", 3}), Hash(temp)); - EXPECT_EQ(HexHash(Hash(std::pair{"hello", 3})), - HexHash(Hash(temp))); - - std::vector t = {1, 2, 3, 4, 5}; - EXPECT_EQ(Hash({1, 2, 3, 4, 5}), Hash({1, 2, 3, 4, 5})); - EXPECT_EQ(Hash(std::set{1, 2, 3}), Hash(std::set{1, 2, 3})); - EXPECT_EQ(Hash(t), Hash(std::vector{1, 2, 3, 4, 5})); - - EXPECT_EQ(StdDataHash(t.data(), t.size()), - StdDataHash(std::vector{1, 2, 3, 4, 5}.data(), t.size())); -} - TEST(UtilTest, MaybeRef) { using StringRef = torch_xla::runtime::util::MaybeRef; std::string storage("String storage"); diff --git a/torch_xla/csrc/runtime/xla_util.cc b/torch_xla/csrc/runtime/xla_util.cc index e591198bf7e..5eb9e009128 100644 --- a/torch_xla/csrc/runtime/xla_util.cc +++ b/torch_xla/csrc/runtime/xla_util.cc @@ -1,5 +1,7 @@ #include "torch_xla/csrc/runtime/xla_util.h" +#include + #include #include #include @@ -19,16 +21,17 @@ namespace runtime { namespace util { namespace { -hash_t SingleShapeHash(const xla::Shape& shape, hash_t seed) { +torch::lazy::hash_t SingleShapeHash(const xla::Shape& shape, + torch::lazy::hash_t seed) { if (shape.has_layout()) { for (auto dim : shape.layout().minor_to_major()) { - seed = HashCombine(seed, dim); + seed = torch::lazy::HashCombine(seed, dim); } } for (auto dim : shape.dimensions()) { - seed = HashCombine(seed, dim); + seed = torch::lazy::HashCombine(seed, dim); } - return HashCombine(seed, static_cast(shape.element_type())); + return torch::lazy::HashCombine(seed, static_cast(shape.element_type())); } void MaybeSaveHloGraph(const std::string& hlo_text, size_t index) { @@ -103,8 +106,8 @@ void CheckComputationStatus( } } -hash_t ShapeHash(const xla::Shape& shape) { - hash_t hash = 0xa5d2d6916; +torch::lazy::hash_t ShapeHash(const xla::Shape& shape) { + torch::lazy::hash_t hash = 0xa5d2d6916; xla::ShapeUtil::ForEachSubshape( shape, [&](const xla::Shape& subshape, const xla::ShapeIndex&) { hash = SingleShapeHash(subshape, hash); diff --git a/torch_xla/csrc/runtime/xla_util.h b/torch_xla/csrc/runtime/xla_util.h index 32b76f69eb9..3163d5ba8c4 100644 --- a/torch_xla/csrc/runtime/xla_util.h +++ b/torch_xla/csrc/runtime/xla_util.h @@ -1,6 +1,8 @@ #ifndef XLA_CLIENT_XLA_UTIL_H_ #define XLA_CLIENT_XLA_UTIL_H_ +#include + #include #include "absl/types/span.h" @@ -35,7 +37,7 @@ void CheckComputationStatus( absl::Span computations, absl::Span output_shapes); -hash_t ShapeHash(const xla::Shape& shape); +torch::lazy::hash_t ShapeHash(const xla::Shape& shape); } // namespace util } // namespace runtime diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index e14b11882a7..a5dc91d27ce 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -38,7 +38,6 @@ #include "torch_xla/csrc/runtime/pjrt_computation_client.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/thread_pool.h" -#include "torch_xla/csrc/runtime/unique.h" #include "torch_xla/csrc/runtime/xla_util.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/torch_util.h" diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index a419bd98b7e..29697963c8a 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -36,7 +36,7 @@ namespace { struct DataAsync { std::vector source_tensors; std::vector async_datas; - std::vector handle_unlockers; + std::vector handle_unlockers; }; bool ShouldUseBF16() { diff --git a/torch_xla/csrc/torch_util.cpp b/torch_xla/csrc/torch_util.cpp index 2148478006f..1d5e3616643 100644 --- a/torch_xla/csrc/torch_util.cpp +++ b/torch_xla/csrc/torch_util.cpp @@ -78,9 +78,7 @@ at::Tensor MaybeWrapTensorToFunctional(const at::Tensor& tensor) { namespace torch { namespace lazy { torch::lazy::hash_t Hash(const xla::Shape& shape) { - auto shape_hash = torch_xla::runtime::util::ShapeHash(shape); - return c10::uint128(absl::Uint128High64(shape_hash), - absl::Uint128Low64(shape_hash)); + return torch_xla::runtime::util::ShapeHash(shape); } } // namespace lazy } // namespace torch diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 39d866358ac..9337f779b4f 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -47,7 +48,6 @@ #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/thread_pool.h" -#include "torch_xla/csrc/runtime/unique.h" #include "torch_xla/csrc/runtime/xla_util.h" #include "torch_xla/csrc/shape_helper.h" #include "torch_xla/csrc/tensor_util.h" @@ -534,7 +534,7 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors( const std::vector& tensors, const SyncTensorsConfig& config) { tsl::profiler::TraceMe activity("CollectSyncTensors", tsl::profiler::TraceMeLevel::kInfo); - runtime::util::Unique unique_device; + torch::lazy::Unique unique_device; for (size_t i = 0; i < tensors.size(); ++i) { unique_device.set(tensors[i]->GetDevice()); } diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index 829f63a806a..c7a870be319 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -16,7 +16,6 @@ #include "torch_xla/csrc/ir.h" #include "torch_xla/csrc/ir_dump_util.h" #include "torch_xla/csrc/lowering_context.h" -#include "torch_xla/csrc/runtime/async_task.h" #include "torch_xla/csrc/runtime/cache.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/multi_wait.h"