Skip to content

Commit

Permalink
Add IsVirtualDeivce() and refactor to remove redundant codes
Browse files Browse the repository at this point in the history
  • Loading branch information
yeounoh committed Mar 12, 2024
1 parent bf60b62 commit 11cc726
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 21 deletions.
6 changes: 6 additions & 0 deletions torch_xla/csrc/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ bool UseVirtualDevice(bool force_spmd) {
return use_virtual_device;
}

bool IsVirtualDevice(const std::string& device) {
XlaDeviceType hw_type =
static_cast<XlaDeviceType>(ParseDeviceString(device).type());
return hw_type == XlaDeviceType::SPMD;
}

bool GetLockSpmdConfig() { return spmd_config_is_locked; }

bool CheckTpuDevice(XlaDeviceType hw_type) {
Expand Down
3 changes: 3 additions & 0 deletions torch_xla/csrc/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ torch::lazy::BackendDevice GetVirtualDevice();
// Optionally, `force_spmd` to set `use_virtual_device` to true.
bool UseVirtualDevice(bool force_spmd = false);

// Return true if `device` is of SPMD device type.
bool IsVirtualDevice(const std::string& device);

// Return true if SPMD config can be switches. That is, no device has been
// initialized, yet.
bool GetLockSpmdConfig();
Expand Down
12 changes: 2 additions & 10 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1939,16 +1939,8 @@ void InitXlaModuleBindings(py::module m) {
std::vector<std::string> sharding_specs;
sharding_specs.reserve(tensors.size());
for (const at::Tensor& tensor : tensors) {
XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
XLATensor::ShardingSpecPtr sharding_spec =
xtensor ? xtensor->sharding_spec() : nullptr;
if (sharding_spec != nullptr) {
sharding_specs.push_back(
xla::HloSharding::FromProto(sharding_spec->sharding)
->ToString());
} else {
sharding_specs.push_back("");
}
sharding_specs.push_back(
GetXLAShardingSpec(bridge::GetXlaTensor(tensor)));
}
return sharding_specs;
});
Expand Down
4 changes: 1 addition & 3 deletions torch_xla/csrc/ops/device_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,10 @@ class DeviceData : public XlaNode {
// backend data with a partitioned one in the node operands. Note that
// this is permitted only if the node holds a placeholder.
void Assign(std::shared_ptr<torch::lazy::BackendData> data) {
// TODO(yeounoh) check if the existing data is a placeholder after we
// address the issue where some of the sync tensors spill with device node.
XLA_CHECK(data->shape() == data_->shape())
<< "Shape mismatch: expected (" << data_->shape().to_string()
<< "), actual (" << data->shape().to_string() << ")";
data_ = data;
data_.reset(data.get());
}

static DeviceData* Cast(const torch::lazy::Node* node);
Expand Down
11 changes: 7 additions & 4 deletions torch_xla/csrc/tensor_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -689,10 +689,13 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
return {};
}

// We assume that caller can't mix virtual device and real device.
if (devices[0] == "SPMD:0") {
// When running in SPMD mode, tensors here in the unsharded
// CreateTensorsData should be implicitly replicated to all devices.
// CreateTensorsData should be implicitly replicated to all devices.
if (IsVirtualDevice(devices[0])) {
XLA_CHECK(
std::all_of(devices.begin(), devices.end(),
[&](const std::string& s) { return s == devices[0]; }))
<< "can't mix virtual device and real device.";

std::vector<std::string> local_devices =
runtime::GetComputationClient()->GetLocalDevices();
std::vector<runtime::ComputationClient::DataPtr> handles;
Expand Down
6 changes: 3 additions & 3 deletions torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1279,9 +1279,9 @@ XLAGraphExecutor::BuildInputOutputAliases(
}

XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
const std::vector<XLATensorPtr>& tensors,
absl::Span<const std::string> devices, const SyncTensorCollection& coll,
PostOrderData* po_data, const std::vector<torch::lazy::Value>& ir_values) {
std::vector<XLATensorPtr>& tensors, absl::Span<const std::string> devices,
const SyncTensorCollection& coll, PostOrderData* po_data,
const std::vector<torch::lazy::Value>& ir_values) {
tsl::profiler::TraceMe activity(
[&] {
return tsl::profiler::TraceMeEncode(
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/xla_graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
std::vector<size_t> SetBufferDonors(LoweringContext* lowering_ctx);

// We don't use upstream Compile to have BuildInputOutputAliases.
CompilationResult Compile(const std::vector<XLATensorPtr>& tensors,
CompilationResult Compile(std::vector<XLATensorPtr>& tensors,
absl::Span<const std::string> devices,
const SyncTensorCollection& coll,
PostOrderData* po_data,
Expand Down

0 comments on commit 11cc726

Please sign in to comment.