Commit

remove unique
will-cromar committed Nov 9, 2023
1 parent 460cf12 commit ebe4567
Showing 6 changed files with 21 additions and 84 deletions.
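At its core, the commit swaps the project's custom runtime::util::Unique&lt;T&gt; wrapper for std::optional in the one pattern that still used it: recording the torch::lazy::BackendDevice shared by a batch of tensors, with a fallback when no tensor produced one. The sketch below is condensed from the hunks that follow, using the diff's own names (only device_opt is renamed here, so the before/after can sit side by side):

    // Before: custom wrapper plus a short-circuiting ternary fallback.
    runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
    unique_device.set(tensor->GetDevice());
    auto device = unique_device ? *unique_device : bridge::GetCurrentDevice();

    // After: plain std::optional, with value_or as the fallback.
    std::optional<torch::lazy::BackendDevice> device_opt;
    device_opt = tensor->GetDevice();
    auto device = device_opt.value_or(bridge::GetCurrentDevice());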
2 changes: 0 additions & 2 deletions torch_xla/csrc/BUILD
@@ -118,10 +118,8 @@ ptxla_cc_library(
         ":layout_manager",
         ":shape_builder",
         ":shape_helper",
-        "//torch_xla/csrc/runtime:async_task",
         "//torch_xla/csrc/runtime",
         "//torch_xla/csrc/runtime:stablehlo_helper",
-        "//torch_xla/csrc/runtime:unique",
         "//torch_xla/csrc/runtime:xla_util",
         "@com_google_absl//absl/hash",
         "@com_google_absl//absl/memory",
32 changes: 15 additions & 17 deletions torch_xla/csrc/debug_util.cpp
@@ -6,6 +6,7 @@
 #include <fstream>
 #include <iostream>
 #include <mutex>
+#include <optional>
 #include <sstream>
 #include <unordered_set>
 
@@ -17,7 +18,6 @@
 #include "torch_xla/csrc/ir_dump_util.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
-#include "torch_xla/csrc/runtime/unique.h"
 #include "torch_xla/csrc/xla_graph_executor.h"
 
 namespace torch_xla {
@@ -61,28 +61,28 @@ std::string DebugUtil::GetTensorsGraphHlo(
     absl::Span<const XLATensorPtr> tensors, const std::vector<size_t>* indices,
     bool dump_stablehlo) {
   std::vector<torch::lazy::Value> root_values;
-  runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
+  std::optional<torch::lazy::BackendDevice> device;
   if (indices != nullptr) {
     for (auto index : *indices) {
       const XLATensorPtr& tensor = tensors[index];
       torch::lazy::Value ir_value = tensor->CurrentIrValue();
       if (ir_value) {
         root_values.push_back(std::move(ir_value));
-        unique_device.set(tensor->GetDevice());
+        device = tensor->GetDevice();
       }
     }
   } else {
     for (auto& tensor : tensors) {
       torch::lazy::Value ir_value = tensor->CurrentIrValue();
       if (ir_value) {
         root_values.push_back(std::move(ir_value));
-        unique_device.set(tensor->GetDevice());
+        device = tensor->GetDevice();
       }
     }
   }
-  return DumpUtil::ToHlo(
-      root_values, unique_device ? *unique_device : bridge::GetCurrentDevice(),
-      EmitMode::kStableHloReadable);
+  return DumpUtil::ToHlo(root_values,
+                         device.value_or(bridge::GetCurrentDevice()),
+                         EmitMode::kStableHloReadable);
 }
 
 std::string DebugUtil::GetTensorsGraphInfo(
@@ -91,7 +91,7 @@ std::string DebugUtil::GetTensorsGraphInfo(
   std::vector<const torch::lazy::Node*> root_nodes;
   std::vector<torch::lazy::Value> root_values;
   std::vector<torch::lazy::hash_t> root_hashes;
-  runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
+  std::optional<torch::lazy::BackendDevice> device;
   if (indices != nullptr) {
     for (auto index : *indices) {
       const XLATensorPtr& tensor = tensors[index];
@@ -100,7 +100,7 @@ std::string DebugUtil::GetTensorsGraphInfo(
         root_nodes.push_back(ir_value.node.get());
         root_hashes.push_back(ir_value.hash());
         root_values.push_back(std::move(ir_value));
-        unique_device.set(tensor->GetDevice());
+        device = tensor->GetDevice();
       }
     }
   } else {
@@ -110,7 +110,7 @@ std::string DebugUtil::GetTensorsGraphInfo(
         root_nodes.push_back(ir_value.node.get());
         root_hashes.push_back(ir_value.hash());
         root_values.push_back(std::move(ir_value));
-        unique_device.set(tensor->GetDevice());
+        device = tensor->GetDevice();
       }
     }
   }
@@ -137,14 +137,12 @@ std::string DebugUtil::GetTensorsGraphInfo(
   } else if (format == GraphFormat::kDot) {
     graph_str = DumpUtil::ToDot(root_nodes);
   } else if (format == GraphFormat::kHlo) {
-    graph_str = DumpUtil::ToHlo(root_values, unique_device
-                                                 ? *unique_device
-                                                 : bridge::GetCurrentDevice());
+    graph_str = DumpUtil::ToHlo(root_values,
+                                device.value_or(bridge::GetCurrentDevice()));
   } else if (format == GraphFormat::kStableHlo) {
-    graph_str = DumpUtil::ToHlo(
-        root_values,
-        unique_device ? *unique_device : bridge::GetCurrentDevice(),
-        EmitMode::kStableHloReadable);
+    graph_str = DumpUtil::ToHlo(root_values,
+                                device.value_or(bridge::GetCurrentDevice()),
+                                EmitMode::kStableHloReadable);
   } else {
     XLA_ERROR() << "Invalid graph format: " << format;
   }
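A subtlety in the rewrite above: the old ternary unique_device ? *unique_device : bridge::GetCurrentDevice() only called bridge::GetCurrentDevice() when no device had been recorded, while std::optional::value_or always evaluates its argument, even when the optional holds a value. That is standard C++ semantics, shown by this self-contained sketch (toy types, not torch_xla code):

    #include <iostream>
    #include <optional>
    #include <string>

    // Stand-in for bridge::GetCurrentDevice(); prints so the eager
    // evaluation is visible.
    std::string current_device() {
      std::cout << "current_device() evaluated\n";
      return "CPU:0";
    }

    int main() {
      std::optional<std::string> device = "TPU:0";
      // Prints "current_device() evaluated" and then "TPU:0": value_or's
      // argument is computed even though the optional is engaged.
      std::cout << device.value_or(current_device()) << "\n";
    }

This only matters when the fallback is expensive or has side effects; the diff implies it is acceptable here.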
9 changes: 0 additions & 9 deletions torch_xla/csrc/runtime/BUILD
@@ -268,15 +268,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "unique",
-    hdrs = ["unique.h"],
-    deps = [
-        ":debug_macros",
-        "@com_google_absl//absl/types:optional",
-    ],
-)
-
 cc_library(
     name = "util",
     srcs = ["util.cc"],
50 changes: 0 additions & 50 deletions torch_xla/csrc/runtime/unique.h

This file was deleted.
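The 50 deleted lines are not rendered on this page. Judging from the call sites in this diff — set(), a boolean test, unary * — and from the deleted Bazel target's deps on :debug_macros and absl::optional (visible above), the header provided a small wrapper along these lines. A hypothetical sketch only, not the verbatim file:

    #include <optional>

    // Hypothetical reconstruction: holds a value expected to be identical
    // across every set() call (e.g. one device shared by all tensors).
    // The real helper's :debug_macros dependency suggests it XLA_CHECKed
    // that later values matched the first; that check is elided here.
    template <typename T>
    class Unique {
     public:
      void set(const T& value) {
        if (!value_.has_value()) {
          value_ = value;
        }
      }
      explicit operator bool() const { return value_.has_value(); }
      const T& operator*() const { return *value_; }

     private:
      std::optional<T> value_;
    };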

2 changes: 1 addition & 1 deletion torch_xla/csrc/tensor.cpp
@@ -15,6 +15,7 @@
 #include <exception>
 #include <functional>
 #include <mutex>
+#include <optional>
 #include <set>
 #include <stdexcept>
 #include <unordered_set>
@@ -38,7 +39,6 @@
 #include "torch_xla/csrc/runtime/pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/thread_pool.h"
-#include "torch_xla/csrc/runtime/unique.h"
 #include "torch_xla/csrc/runtime/xla_util.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/torch_util.h"
10 changes: 5 additions & 5 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -17,6 +17,7 @@
 #include <fstream>
 #include <functional>
 #include <mutex>
+#include <optional>
 #include <set>
 #include <stdexcept>
 #include <unordered_map>
@@ -47,7 +48,6 @@
 #include "torch_xla/csrc/runtime/stablehlo_helper.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/thread_pool.h"
-#include "torch_xla/csrc/runtime/unique.h"
 #include "torch_xla/csrc/runtime/xla_util.h"
 #include "torch_xla/csrc/shape_helper.h"
 #include "torch_xla/csrc/tensor_util.h"
@@ -534,12 +534,12 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
     const std::vector<XLATensorPtr>& tensors, const SyncTensorsConfig& config) {
   tsl::profiler::TraceMe activity("CollectSyncTensors",
                                   tsl::profiler::TraceMeLevel::kInfo);
-  runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
+  std::optional<torch::lazy::BackendDevice> device;
   for (size_t i = 0; i < tensors.size(); ++i) {
-    unique_device.set(tensors[i]->GetDevice());
+    device = tensors[i]->GetDevice();
   }
   SyncTensorCollection coll;
-  if (!unique_device) {
+  if (!device) {
     return coll;
   }
 
@@ -552,7 +552,7 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
   // graph with on/off force_ltc_data should not match, hash wise.
   coll.hash = torch::lazy::MHash(config.force_ltc_data);
   coll.config = config;
-  coll.device = *unique_device;
+  coll.device = *device;
   coll.indices.reserve(tensors.size());
   for (size_t i = 0; i < tensors.size(); ++i) {
     if (tensor_ids.insert(tensors[i]->GetUniqueId()).second &&
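Worth noting: the loop in CollectSyncTensors now overwrites device on every iteration, so coll.device is simply the last tensor's device, and std::optional performs no cross-tensor consistency check. If the removed Unique::set verified that all tensors shared one device (as its name and :debug_macros dependency hint), that guard is gone, and callers are presumably expected to pass same-device tensors.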
