Add IsVirtualDevice() and refactor #6726

Merged: 1 commit, Mar 13, 2024
6 changes: 6 additions & 0 deletions torch_xla/csrc/device.cpp

@@ -91,6 +91,12 @@ bool UseVirtualDevice(bool force_spmd) {
   return use_virtual_device;
 }
 
+bool IsVirtualDevice(const std::string& device) {
+  XlaDeviceType hw_type =
+      static_cast<XlaDeviceType>(ParseDeviceString(device).type());
+  return hw_type == XlaDeviceType::SPMD;
+}
+
 bool GetLockSpmdConfig() { return spmd_config_is_locked; }
 
 bool CheckTpuDevice(XlaDeviceType hw_type) {
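For orientation, here is a minimal standalone sketch of the behavior IsVirtualDevice() encodes. The real implementation goes through ParseDeviceString() and XlaDeviceType; the prefix check below is an illustrative simplification, assuming the virtual device is always addressed as "SPMD:<ordinal>":

// Standalone sketch, not the torch_xla implementation.
#include <cassert>
#include <string>

bool IsVirtualDeviceSketch(const std::string& device) {
  // The virtual SPMD device stands in for all physical devices, so the
  // check reduces to "does the device string name the SPMD device type?".
  return device.rfind("SPMD:", 0) == 0;
}

int main() {
  assert(IsVirtualDeviceSketch("SPMD:0"));   // the virtual device
  assert(!IsVirtualDeviceSketch("TPU:0"));   // a real device
  assert(!IsVirtualDeviceSketch("CUDA:1"));  // a real device
  return 0;
}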
3 changes: 3 additions & 0 deletions torch_xla/csrc/device.h

@@ -45,6 +45,9 @@ torch::lazy::BackendDevice GetVirtualDevice();
 // Optionally, `force_spmd` to set `use_virtual_device` to true.
 bool UseVirtualDevice(bool force_spmd = false);
 
+// Return true if `device` is of SPMD device type.
+bool IsVirtualDevice(const std::string& device);
+
 // Return true if the SPMD config can be switched. That is, no device has
 // been initialized, yet.
 bool GetLockSpmdConfig();
12 changes: 2 additions & 10 deletions torch_xla/csrc/init_python_bindings.cpp

@@ -1939,16 +1939,8 @@ void InitXlaModuleBindings(py::module m) {
         std::vector<std::string> sharding_specs;
         sharding_specs.reserve(tensors.size());
         for (const at::Tensor& tensor : tensors) {
-          XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
-          XLATensor::ShardingSpecPtr sharding_spec =
-              xtensor ? xtensor->sharding_spec() : nullptr;
-          if (sharding_spec != nullptr) {
-            sharding_specs.push_back(
-                xla::HloSharding::FromProto(sharding_spec->sharding)
-                    ->ToString());
-          } else {
-            sharding_specs.push_back("");
-          }
+          sharding_specs.push_back(
+              GetXLAShardingSpec(bridge::GetXlaTensor(tensor)));
         }
         return sharding_specs;
       });

Collaborator: What will GetXLAShardingSpec return when sharding_spec is null?

Contributor Author: An empty string (""), as expected.
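From the inline logic this call replaces, and the answer above, the helper presumably behaves like the following reconstruction. This is a sketch inferred from the removed lines, not copied from GetXLAShardingSpec's actual definition:

// Reconstructed sketch: the HloSharding string for a sharded tensor,
// "" when there is no XLA tensor or no sharding spec.
std::string GetXLAShardingSpecSketch(const XLATensorPtr& xtensor) {
  XLATensor::ShardingSpecPtr sharding_spec =
      xtensor ? xtensor->sharding_spec() : nullptr;
  if (sharding_spec == nullptr) {
    return "";  // unsharded -> empty string, matching the old else-branch
  }
  return xla::HloSharding::FromProto(sharding_spec->sharding)->ToString();
}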
4 changes: 1 addition & 3 deletions torch_xla/csrc/ops/device_data.h

@@ -36,12 +36,10 @@ class DeviceData : public XlaNode {
   // backend data with a partitioned one in the node operands. Note that
   // this is permitted only if the node holds a placeholder.
   void Assign(std::shared_ptr<torch::lazy::BackendData> data) {
-    // TODO(yeounoh) check if the existing data is a placeholder after we
-    // address the issue where some of the sync tensors spill with device node.
     XLA_CHECK(data->shape() == data_->shape())
         << "Shape mismatch: expected (" << data_->shape().to_string()
         << "), actual (" << data->shape().to_string() << ")";
-    data_ = data;
+    data_.reset(data.get());
   }
 
   static DeviceData* Cast(const torch::lazy::Node* node);

Comment on lines -44 to +42

Collaborator: What's the difference between these two?

Contributor Author: Better memory management; this refactoring reuses the existing data_ to hold the new data.
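To unpack the reviewer's question: the two lines differ in how std::shared_ptr ownership is tracked. Below is a self-contained illustration of the general semantics involved (plain C++, not torch_xla code; the no-op deleter exists only to keep the demo safe to run):

#include <iostream>
#include <memory>

struct Payload {
  int value = 0;
};

int main() {
  auto incoming = std::make_shared<Payload>();

  // `data_ = data;` -- copy assignment shares the existing control
  // block, so both pointers co-own the object and the count rises.
  std::shared_ptr<Payload> assigned = incoming;
  std::cout << incoming.use_count() << "\n";  // prints 2

  // `data_.reset(data.get());` -- reset with a raw pointer starts a
  // *new* control block for the same object. The no-op deleter makes
  // that safe here; without one, two independent owners would each try
  // to delete the Payload when they go out of scope.
  std::shared_ptr<Payload> reset_owner(incoming.get(),
                                       [](Payload*) { /* no-op */ });
  std::cout << incoming.use_count() << " "       // still 2
            << reset_owner.use_count() << "\n";  // 1: separate count
  return 0;
}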
11 changes: 7 additions & 4 deletions torch_xla/csrc/tensor_util.cpp

@@ -689,10 +689,13 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
     return {};
   }
 
-  // We assume that caller can't mix virtual device and real device.
-  if (devices[0] == "SPMD:0") {
-    // When running in SPMD mode, tensors here in the unsharded
-    // CreateTensorsData should be implicitly replicated to all devices.
+  // When running in SPMD mode, tensors here in the unsharded
+  // CreateTensorsData should be implicitly replicated to all devices.
+  if (IsVirtualDevice(devices[0])) {
+    XLA_CHECK(
+        std::all_of(devices.begin(), devices.end(),
+                    [&](const std::string& s) { return s == devices[0]; }))
+        << "can't mix virtual device and real device.";
 
     std::vector<std::string> local_devices =
         runtime::GetComputationClient()->GetLocalDevices();
     std::vector<runtime::ComputationClient::DataPtr> handles;
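The XLA_CHECK turns the old "we assume" comment into an enforced invariant: every requested device must be the same (virtual) device. A standalone sketch of the same std::all_of homogeneity check:

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Mirrors the check in CreateTensorsData: every entry must equal
// devices[0], so virtual (SPMD) and real devices cannot be mixed.
bool AllSameDevice(const std::vector<std::string>& devices) {
  if (devices.empty()) return true;  // vacuously homogeneous
  return std::all_of(devices.begin(), devices.end(),
                     [&](const std::string& s) { return s == devices[0]; });
}

int main() {
  assert(AllSameDevice({"SPMD:0", "SPMD:0"}));  // all virtual: ok
  assert(!AllSameDevice({"SPMD:0", "TPU:0"}));  // mixed: rejected
  return 0;
}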
6 changes: 3 additions & 3 deletions torch_xla/csrc/xla_graph_executor.cpp

@@ -1296,9 +1296,9 @@ XLAGraphExecutor::BuildInputOutputAliases(
 }
 
 XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
-    const std::vector<XLATensorPtr>& tensors,
-    absl::Span<const std::string> devices, const SyncTensorCollection& coll,
-    PostOrderData* po_data, const std::vector<torch::lazy::Value>& ir_values) {
+    std::vector<XLATensorPtr>& tensors, absl::Span<const std::string> devices,
+    const SyncTensorCollection& coll, PostOrderData* po_data,
+    const std::vector<torch::lazy::Value>& ir_values) {
   tsl::profiler::TraceMe activity(
       [&] {
         return tsl::profiler::TraceMeEncode(
4 changes: 3 additions & 1 deletion torch_xla/csrc/xla_graph_executor.h

@@ -346,7 +346,9 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
   std::vector<size_t> SetBufferDonors(LoweringContext* lowering_ctx);
 
   // We don't use upstream Compile to have BuildInputOutputAliases.
-  CompilationResult Compile(const std::vector<XLATensorPtr>& tensors,
+  // TODO(yeounoh) auto-sharding can change tensors shardings, which needs to be
+  // accounted for in Dynamo integration.
+  CompilationResult Compile(std::vector<XLATensorPtr>& tensors,
                             absl::Span<const std::string> devices,
                             const SyncTensorCollection& coll,
                             PostOrderData* po_data,

Collaborator: I can see how this will become problematic for Dynamo, since in Dynamo we first dry-run the compilation and don't execute the compiled program. Please leave a TODO somewhere for the Dynamo integration with auto-sharding.

Contributor Author: Makes sense. I will also add a section in the design doc.

Contributor Author:

  // We don't use upstream Compile to have BuildInputOutputAliases.
  // TODO(yeounoh) auto-sharding can change tensors shardings, which needs to be
  // accounted for in Dynamo integration.
  CompilationResult Compile(std::vector<XLATensorPtr>& tensors,
                            absl::Span<const std::string> devices,
                            const SyncTensorCollection& coll,
                            PostOrderData* po_data,
                            const std::vector<torch::lazy::Value>& ir_values);
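Why the const had to go: with auto-sharding, compilation itself may rewrite each tensor's sharding, so the input vector is mutated even during a dry run. A self-contained sketch of that shape of API, using hypothetical Tensor and ChooseSharding names that are not torch_xla's:

#include <string>
#include <vector>

// Hypothetical stand-ins, not torch_xla types.
struct Tensor {
  std::string sharding;  // e.g. "replicated", "tiled"
};

std::string ChooseSharding(const Tensor&) {
  return "tiled";  // placeholder for an auto-sharding decision
}

// Sketch: a non-const `tensors` lets compilation record the shardings it
// picked -- exactly the side effect a Dynamo dry run must account for.
void CompileSketch(std::vector<Tensor>& tensors) {
  for (Tensor& t : tensors) {
    t.sharding = ChooseSharding(t);
  }
}

int main() {
  std::vector<Tensor> tensors(2);
  CompileSketch(tensors);  // tensors[i].sharding now set by "compilation"
  return 0;
}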