diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index d4eed1ec83f..8a38c87cd00 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -118,10 +118,8 @@ ptxla_cc_library( ":layout_manager", ":shape_builder", ":shape_helper", - "//torch_xla/csrc/runtime:async_task", "//torch_xla/csrc/runtime", "//torch_xla/csrc/runtime:stablehlo_helper", - "//torch_xla/csrc/runtime:unique", "//torch_xla/csrc/runtime:xla_util", "@com_google_absl//absl/hash", "@com_google_absl//absl/memory", diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp index 7723d6d95d9..e601034790b 100644 --- a/torch_xla/csrc/debug_util.cpp +++ b/torch_xla/csrc/debug_util.cpp @@ -6,6 +6,7 @@ #include #include #include +#include <optional> #include #include @@ -17,7 +18,6 @@ #include "torch_xla/csrc/ir_dump_util.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/sys_util.h" -#include "torch_xla/csrc/runtime/unique.h" #include "torch_xla/csrc/xla_graph_executor.h" namespace torch_xla { @@ -61,14 +61,14 @@ std::string DebugUtil::GetTensorsGraphHlo( absl::Span<const XLATensorPtr> tensors, const std::vector<size_t>* indices, bool dump_stablehlo) { std::vector<torch::lazy::Value> root_values; - runtime::util::Unique<torch::lazy::BackendDevice> unique_device; + std::optional<torch::lazy::BackendDevice> device; if (indices != nullptr) { for (auto index : *indices) { const XLATensorPtr& tensor = tensors[index]; torch::lazy::Value ir_value = tensor->CurrentIrValue(); if (ir_value) { root_values.push_back(std::move(ir_value)); - unique_device.set(tensor->GetDevice()); + device = tensor->GetDevice(); } } } else { @@ -76,13 +76,13 @@ std::string DebugUtil::GetTensorsGraphHlo( torch::lazy::Value ir_value = tensor->CurrentIrValue(); if (ir_value) { root_values.push_back(std::move(ir_value)); - unique_device.set(tensor->GetDevice()); + device = tensor->GetDevice(); } } } - return DumpUtil::ToHlo( - root_values, unique_device ? 
*unique_device : bridge::GetCurrentDevice(), - EmitMode::kStableHloReadable); + return DumpUtil::ToHlo(root_values, + device.value_or(bridge::GetCurrentDevice()), + EmitMode::kStableHloReadable); } std::string DebugUtil::GetTensorsGraphInfo( @@ -91,7 +91,7 @@ std::string DebugUtil::GetTensorsGraphInfo( std::vector<const torch::lazy::Node*> root_nodes; std::vector<torch::lazy::Value> root_values; std::vector<torch::lazy::hash_t> root_hashes; - runtime::util::Unique<torch::lazy::BackendDevice> unique_device; + std::optional<torch::lazy::BackendDevice> device; if (indices != nullptr) { for (auto index : *indices) { const XLATensorPtr& tensor = tensors[index]; @@ -100,7 +100,7 @@ std::string DebugUtil::GetTensorsGraphInfo( root_nodes.push_back(ir_value.node.get()); root_hashes.push_back(ir_value.hash()); root_values.push_back(std::move(ir_value)); - unique_device.set(tensor->GetDevice()); + device = tensor->GetDevice(); } } } else { @@ -110,7 +110,7 @@ std::string DebugUtil::GetTensorsGraphInfo( root_nodes.push_back(ir_value.node.get()); root_hashes.push_back(ir_value.hash()); root_values.push_back(std::move(ir_value)); - unique_device.set(tensor->GetDevice()); + device = tensor->GetDevice(); } } } @@ -137,14 +137,12 @@ std::string DebugUtil::GetTensorsGraphInfo( } else if (format == GraphFormat::kDot) { graph_str = DumpUtil::ToDot(root_nodes); } else if (format == GraphFormat::kHlo) { - graph_str = DumpUtil::ToHlo(root_values, unique_device - ? *unique_device - : bridge::GetCurrentDevice()); + graph_str = DumpUtil::ToHlo(root_values, + device.value_or(bridge::GetCurrentDevice())); } else if (format == GraphFormat::kStableHlo) { - graph_str = DumpUtil::ToHlo( - root_values, - unique_device ? 
*unique_device : bridge::GetCurrentDevice(), - EmitMode::kStableHloReadable); + graph_str = DumpUtil::ToHlo(root_values, + device.value_or(bridge::GetCurrentDevice()), + EmitMode::kStableHloReadable); } else { XLA_ERROR() << "Invalid graph format: " << format; } diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 9b549312bba..2878f1c4b1d 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -268,15 +268,6 @@ cc_library( ], ) -cc_library( - name = "unique", - hdrs = ["unique.h"], - deps = [ - ":debug_macros", - "@com_google_absl//absl/types:optional", - ], -) - cc_library( name = "util", srcs = ["util.cc"], diff --git a/torch_xla/csrc/runtime/unique.h b/torch_xla/csrc/runtime/unique.h deleted file mode 100644 index f50e24320d9..00000000000 --- a/torch_xla/csrc/runtime/unique.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef XLA_CLIENT_UNIQUE_H_ -#define XLA_CLIENT_UNIQUE_H_ - -#include <functional> -#include <set> - -#include "absl/types/optional.h" -#include "torch_xla/csrc/runtime/debug_macros.h" -namespace torch_xla { -namespace runtime { -namespace util { - -// Helper class to allow tracking zero or more things, which should be forcibly -// be one only thing. 
-template <typename T, typename C = std::equal_to<T>> -class Unique { - public: - std::pair<bool, const T&> set(const T& value) { - if (value_) { - XLA_CHECK(C()(*value_, value)) - << "'" << *value_ << "' vs '" << value << "'"; - return std::pair<bool, const T&>(false, *value_); - } - value_ = value; - return std::pair<bool, const T&>(true, *value_); - } - - operator bool() const { return value_.has_value(); } - operator const T&() const { return *value_; } - const T& operator*() const { return *value_; } - const T* operator->() const { return value_.operator->(); } - - std::set<T, C> AsSet() const { - std::set<T, C> vset; - if (value_.has_value()) { - vset.insert(*value_); - } - return vset; - } - - private: - absl::optional<T> value_; -}; - -} // namespace util -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_UNIQUE_H_ diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index e14b11882a7..85378de0f8b 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -15,6 +15,7 @@ #include #include #include +#include <optional> #include #include #include @@ -38,7 +39,6 @@ #include "torch_xla/csrc/runtime/pjrt_computation_client.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/thread_pool.h" -#include "torch_xla/csrc/runtime/unique.h" #include "torch_xla/csrc/runtime/xla_util.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/torch_util.h" diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 39d866358ac..659dbfa8834 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -17,6 +17,7 @@ #include #include #include +#include <optional> #include #include #include @@ -47,7 +48,6 @@ #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/thread_pool.h" -#include "torch_xla/csrc/runtime/unique.h" #include "torch_xla/csrc/runtime/xla_util.h" #include "torch_xla/csrc/shape_helper.h" #include "torch_xla/csrc/tensor_util.h" @@ -534,12 +534,12 @@
XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors( const std::vector<XLATensorPtr>& tensors, const SyncTensorsConfig& config) { tsl::profiler::TraceMe activity("CollectSyncTensors", tsl::profiler::TraceMeLevel::kInfo); - runtime::util::Unique<torch::lazy::BackendDevice> unique_device; + std::optional<torch::lazy::BackendDevice> device; for (size_t i = 0; i < tensors.size(); ++i) { - unique_device.set(tensors[i]->GetDevice()); + device = tensors[i]->GetDevice(); } SyncTensorCollection coll; - if (!unique_device) { + if (!device) { return coll; } @@ -552,7 +552,7 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors( // graph with on/off force_ltc_data should not match, hash wise. coll.hash = torch::lazy::MHash(config.force_ltc_data); coll.config = config; - coll.device = *unique_device; + coll.device = *device; coll.indices.reserve(tensors.size()); for (size_t i = 0; i < tensors.size(); ++i) { if (tensor_ids.insert(tensors[i]->GetUniqueId()).second &&