From 11cc7266857beaab538645db5f91222aab9593d6 Mon Sep 17 00:00:00 2001 From: Yeounoh Chung Date: Tue, 12 Mar 2024 13:51:12 -0700 Subject: [PATCH] Add IsVirtualDeivce() and refactor to remove redundant codes --- torch_xla/csrc/device.cpp | 6 ++++++ torch_xla/csrc/device.h | 3 +++ torch_xla/csrc/init_python_bindings.cpp | 12 ++---------- torch_xla/csrc/ops/device_data.h | 4 +--- torch_xla/csrc/tensor_util.cpp | 11 +++++++---- torch_xla/csrc/xla_graph_executor.cpp | 6 +++--- torch_xla/csrc/xla_graph_executor.h | 2 +- 7 files changed, 23 insertions(+), 21 deletions(-) diff --git a/torch_xla/csrc/device.cpp b/torch_xla/csrc/device.cpp index 443023009a21..16c770019043 100644 --- a/torch_xla/csrc/device.cpp +++ b/torch_xla/csrc/device.cpp @@ -91,6 +91,12 @@ bool UseVirtualDevice(bool force_spmd) { return use_virtual_device; } +bool IsVirtualDevice(const std::string& device) { + XlaDeviceType hw_type = + static_cast(ParseDeviceString(device).type()); + return hw_type == XlaDeviceType::SPMD; +} + bool GetLockSpmdConfig() { return spmd_config_is_locked; } bool CheckTpuDevice(XlaDeviceType hw_type) { diff --git a/torch_xla/csrc/device.h b/torch_xla/csrc/device.h index 7161b03eb6ba..84fdf7f0b164 100644 --- a/torch_xla/csrc/device.h +++ b/torch_xla/csrc/device.h @@ -45,6 +45,9 @@ torch::lazy::BackendDevice GetVirtualDevice(); // Optionally, `force_spmd` to set `use_virtual_device` to true. bool UseVirtualDevice(bool force_spmd = false); +// Return true if `device` is of SPMD device type. +bool IsVirtualDevice(const std::string& device); + // Return true if SPMD config can be switches. That is, no device has been // initialized, yet. bool GetLockSpmdConfig(); diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 86a08238ec8f..c66bfe915b0b 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1939,16 +1939,8 @@ void InitXlaModuleBindings(py::module m) { std::vector sharding_specs; sharding_specs.reserve(tensors.size()); for (const at::Tensor& tensor : tensors) { - XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); - XLATensor::ShardingSpecPtr sharding_spec = - xtensor ? xtensor->sharding_spec() : nullptr; - if (sharding_spec != nullptr) { - sharding_specs.push_back( - xla::HloSharding::FromProto(sharding_spec->sharding) - ->ToString()); - } else { - sharding_specs.push_back(""); - } + sharding_specs.push_back( + GetXLAShardingSpec(bridge::GetXlaTensor(tensor))); } return sharding_specs; }); diff --git a/torch_xla/csrc/ops/device_data.h b/torch_xla/csrc/ops/device_data.h index dc70924be3da..217f2cef197d 100644 --- a/torch_xla/csrc/ops/device_data.h +++ b/torch_xla/csrc/ops/device_data.h @@ -36,12 +36,10 @@ class DeviceData : public XlaNode { // backend data with a partitioned one in the node operands. Note that // this is permitted only if the node holds a placeholder. void Assign(std::shared_ptr data) { - // TODO(yeounoh) check if the existing data is a placeholder after we - // address the issue where some of the sync tensors spill with device node. XLA_CHECK(data->shape() == data_->shape()) << "Shape mismatch: expected (" << data_->shape().to_string() << "), actual (" << data->shape().to_string() << ")"; - data_ = data; + data_.reset(data.get()); } static DeviceData* Cast(const torch::lazy::Node* node); diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index 9c00e212ccc1..db9d83907243 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -689,10 +689,13 @@ std::vector CreateTensorsData( return {}; } - // We assume that caller can't mix virtual device and real device. - if (devices[0] == "SPMD:0") { - // When running in SPMD mode, tensors here in the unsharded - // CreateTensorsData should be implicitly replicated to all devices. + // CreateTensorsData should be implicitly replicated to all devices. + if (IsVirtualDevice(devices[0])) { + XLA_CHECK( + std::all_of(devices.begin(), devices.end(), + [&](const std::string& s) { return s == devices[0]; })) + << "can't mix virtual device and real device."; + std::vector local_devices = runtime::GetComputationClient()->GetLocalDevices(); std::vector handles; diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 22e747b1362e..dd5e3faff590 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -1279,9 +1279,9 @@ XLAGraphExecutor::BuildInputOutputAliases( } XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( - const std::vector& tensors, - absl::Span devices, const SyncTensorCollection& coll, - PostOrderData* po_data, const std::vector& ir_values) { + std::vector& tensors, absl::Span devices, + const SyncTensorCollection& coll, PostOrderData* po_data, + const std::vector& ir_values) { tsl::profiler::TraceMe activity( [&] { return tsl::profiler::TraceMeEncode( diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index 4d7d29c1b147..76ab0ef5f055 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -346,7 +346,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { std::vector SetBufferDonors(LoweringContext* lowering_ctx); // We don't use upstream Compile to have BuildInputOutputAliases. - CompilationResult Compile(const std::vector& tensors, + CompilationResult Compile(std::vector& tensors, absl::Span devices, const SyncTensorCollection& coll, PostOrderData* po_data,