From c923e8f13b1d06f8d9661e0e1de5c2f6149f3505 Mon Sep 17 00:00:00 2001 From: Yeounoh Chung Date: Wed, 20 Mar 2024 13:42:39 -0700 Subject: [PATCH] [SPMD] auto-construct auto-sharding mesh ids (#6770) (#6782) --- WORKSPACE | 5 +- openxla_patches/quant_dequant_converter.diff | 122 ------------------- setup.py | 2 +- torch_xla/csrc/xla_graph_executor.cpp | 8 +- torch_xla/csrc/xla_sharding_util.cpp | 42 ++++++- torch_xla/csrc/xla_sharding_util.h | 5 +- 6 files changed, 47 insertions(+), 137 deletions(-) delete mode 100644 openxla_patches/quant_dequant_converter.diff diff --git a/WORKSPACE b/WORKSPACE index b30c5f591c7..a1c4a3256bb 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -49,11 +49,10 @@ http_archive( "//openxla_patches:cache_urls.diff", "//openxla_patches:gpu_race_condition.diff", "//openxla_patches:f16_abi_clang.diff", - "//openxla_patches:quant_dequant_converter.diff", ], - strip_prefix = "xla-18cbd2019898d3a7b563aeb73683f0c5a6ce14fd", + strip_prefix = "xla-25c8a6781af6be51d3bc43a0953b07803ab761ea", urls = [ - "https://github.com/openxla/xla/archive/18cbd2019898d3a7b563aeb73683f0c5a6ce14fd.tar.gz", + "https://github.com/openxla/xla/archive/25c8a6781af6be51d3bc43a0953b07803ab761ea.tar.gz", ], ) diff --git a/openxla_patches/quant_dequant_converter.diff b/openxla_patches/quant_dequant_converter.diff deleted file mode 100644 index adaca99b9b5..00000000000 --- a/openxla_patches/quant_dequant_converter.diff +++ /dev/null @@ -1,122 +0,0 @@ -diff --git a/xla/translate/hlo_to_mhlo/BUILD b/xla/translate/hlo_to_mhlo/BUILD -index a75baa9dad..a09fca4898 100644 ---- a/xla/translate/hlo_to_mhlo/BUILD -+++ b/xla/translate/hlo_to_mhlo/BUILD -@@ -40,6 +40,7 @@ cc_library( - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AsmParser", - "@llvm-project//mlir:IR", -+ "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:Support", - ], - ) -diff --git a/xla/translate/hlo_to_mhlo/custom_call_importer.cc b/xla/translate/hlo_to_mhlo/custom_call_importer.cc -index 
9250747210..9cba4c992a 100644 ---- a/xla/translate/hlo_to_mhlo/custom_call_importer.cc -+++ b/xla/translate/hlo_to_mhlo/custom_call_importer.cc -@@ -21,6 +21,7 @@ limitations under the License. - #include "absl/strings/match.h" - #include "llvm/ADT/STLExtras.h" - #include "mlir/AsmParser/AsmParser.h" // from @llvm-project -+#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project - #include "mlir/IR/Builders.h" // from @llvm-project - #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project - #include "mlir/IR/Location.h" // from @llvm-project -@@ -93,6 +94,66 @@ absl::StatusOr ImportRealDynamicSliceOp( - - } // namespace - -+mlir::Type getQuantizedType(mlir::DictionaryAttr& backend_config) { -+ std::vector scales; -+ std::vector zero_points; -+ int64_t quantization_dimension = -1, storage_max = 0, storage_min = 0; -+ mlir::Type storage_type, expressed_type; -+ -+ auto scales_attr = backend_config.get("scale"); -+ if (scales_attr) { -+ for (auto scale_attr : scales_attr.cast()) { -+ scales.push_back(scale_attr.cast().getValueAsDouble()); -+ } -+ } -+ -+ auto zero_points_attr = backend_config.get("zero_point"); -+ if (zero_points_attr) { -+ for (auto zero_point_attr : zero_points_attr.cast()) { -+ zero_points.push_back(zero_point_attr.cast().getInt()); -+ } -+ } -+ -+ auto quantization_dimension_attr = -+ backend_config.get("quantization_dimension"); -+ if (quantization_dimension_attr) { -+ quantization_dimension = -+ quantization_dimension_attr.cast().getInt(); -+ } -+ -+ auto storage_max_attr = backend_config.get("storage_max"); -+ if (storage_max_attr) { -+ storage_max = storage_max_attr.cast().getInt(); -+ } -+ -+ auto storage_min_attr = backend_config.get("storage_min"); -+ if (storage_min_attr) { -+ storage_min = storage_min_attr.cast().getInt(); -+ } -+ -+ auto storage_type_attr = backend_config.get("storage_type"); -+ if (storage_type_attr) { -+ storage_type = storage_type_attr.cast().getValue(); -+ } -+ -+ auto expressed_type_attr = 
backend_config.get("expressed_type"); -+ if (expressed_type_attr) { -+ expressed_type = expressed_type_attr.cast().getValue(); -+ } -+ -+ auto is_signed = storage_type.cast().isSignless(); -+ -+ if (quantization_dimension != -1) { -+ return mlir::quant::UniformQuantizedPerAxisType::get( -+ is_signed, storage_type, expressed_type, scales, zero_points, -+ quantization_dimension, storage_min, storage_max); -+ } else { -+ return mlir::quant::UniformQuantizedType::get( -+ is_signed, storage_type, expressed_type, scales[0], zero_points[0], -+ storage_min, storage_max); -+ } -+} -+ - absl::StatusOr ImportCustomCallAsOp( - const HloCustomCallInstruction* instruction, mlir::Location loc, - mlir::Type result_type, mlir::ValueRange operands, -@@ -112,6 +173,30 @@ absl::StatusOr ImportCustomCallAsOp( - return ImportRealDynamicSliceOp(backend_config_str, loc, result_type, - operands, builder); - } -+ -+ auto attr = mlir::parseAttribute(backend_config_str, builder->getContext()) -+ .dyn_cast(); -+ if (!attr) { -+ return Internal( -+ "Couldn't parse backend config into a dictionary attribute"); -+ } -+ auto backend_config = attr.cast(); -+ if (custom_call_target == "mhlo.uniform_quantize") { -+ return builder -+ ->create( -+ loc, -+ mlir::RankedTensorType::get( -+ result_type.cast().getShape(), -+ getQuantizedType(backend_config)), -+ operands) -+ .getOperation(); -+ } -+ -+ if (custom_call_target == "mhlo.uniform_dequantize") { -+ return builder -+ ->create(loc, result_type, operands) -+ .getOperation(); -+ } - return InvalidArgument("Unsupported MHLO op custom_call %s", - custom_call_target); - } diff --git a/setup.py b/setup.py index 249dbb307c1..16f14494f74 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,7 @@ base_dir = os.path.dirname(os.path.abspath(__file__)) -_date = '20240305' +_date = '20240320' _libtpu_version = f'0.1.dev{_date}' _libtpu_storage_path = 
f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl' _jax_version = f'0.4.26.dev{_date}' diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index d5a2f1c72f4..39e036a1356 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -1402,9 +1402,11 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( // Apply XLA_AUTO_SPMD_MESH if it is set. // TODO(yeounoh) allow multi mesh exploration. - auto mesh_shape_ids = ShardingUtil::GetAutoShardingMesh(); - std::vector auto_spmd_mesh_shape = std::get<0>(mesh_shape_ids); - std::vector auto_spmd_mesh_ids = std::get<1>(mesh_shape_ids); + std::vector auto_spmd_mesh_shape = + ShardingUtil::GetAutoShardingMesh(); + std::vector auto_spmd_mesh_ids = + ShardingUtil::GetAutoShardingMeshIds( + instances.front().computation.proto()); instances.front().auto_spmd_mesh_shape = auto_spmd_mesh_shape; instances.front().auto_spmd_mesh_ids = auto_spmd_mesh_ids; TF_VLOG(5) << "auto_spmd_mesh_shape={" diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 7f8b96254ff..595d3862012 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -624,14 +624,12 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( source_tensors, GetVirtualDevice().toString(), global_shape, sharding); } -std::tuple, std::vector> -ShardingUtil::GetAutoShardingMesh() { +std::vector ShardingUtil::GetAutoShardingMesh() { // Auto-sharding uses mesh_shape = {n_devices, 1} if XLA_AUTO_SPMD_MESH // is not set. XLA_AUTO_SPMD_MESH takes a form of string, "2,2" which // corresponds to a 2-by-2 mesh. 
std::vector mesh_shape = ParseStringToIntVector( runtime::sys_util::GetEnvString("XLA_AUTO_SPMD_MESH", "")); - std::vector device_mesh_ids; if (!mesh_shape.empty()) { int64_t total_devices = 1; for (auto i : mesh_shape) { @@ -641,10 +639,42 @@ ShardingUtil::GetAutoShardingMesh() { runtime::GetComputationClient()->GetAllDevices().size()) << "Invalid auto-sharding mesh_shape: " << absl::StrJoin(mesh_shape, ","); - device_mesh_ids = std::vector(total_devices); - std::iota(device_mesh_ids.begin(), device_mesh_ids.end(), 0); } - return std::make_tuple(mesh_shape, device_mesh_ids); + return mesh_shape; +} + +std::vector ShardingUtil::GetAutoShardingMeshIds( + const xla::HloModuleProto& module) { + // Return the first non-default (iota) mesh ids arrangement, as we expect + // only one such assignment and/or the logical mesh device assignment should + // be compatible with the other arrangements in the HLO. This is a work-around + // as the auto-sharding pass takes only one arrangement for now. + // TODO(yeounoh) this was not necessary before; replace if this can be done + // during the auto-sharding pass. + int64_t n_devices = runtime::GetComputationClient()->GetAllDevices().size(); + std::vector device_mesh_ids = std::vector(n_devices); + std::iota(device_mesh_ids.begin(), device_mesh_ids.end(), 0); + + // Unfortunately, we have to go through the instructions since + // `spmd_parameters_shardings` is not available. + for (auto computation : module.computations()) { + for (auto instruction : computation.instructions()) { + if (instruction.opcode() == "parameter" && instruction.has_sharding()) { + xla::OpSharding sharding = instruction.sharding(); + auto tile_assignment_devices = sharding.tile_assignment_devices(); + if (!tile_assignment_devices.empty()) { + auto new_mesh_ids = std::vector( + tile_assignment_devices.begin(), tile_assignment_devices.end()); + // return the first non-default (iota) device assignments. 
+ if (new_mesh_ids != device_mesh_ids) { + return new_mesh_ids; + } + } + } + } + } + // return the default (iota) device assignments. + return device_mesh_ids; } void ShardingUtil::ReshardParameters( diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index d25aee9e4a2..5e0a414b00c 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -126,8 +126,9 @@ class ShardingUtil { // Construct a device mesh for auto-sharding pass. Returns a tuple of mesh // shape and device ids vectors. - static std::tuple, std::vector> - GetAutoShardingMesh(); + static std::vector GetAutoShardingMesh(); + static std::vector GetAutoShardingMeshIds( + const xla::HloModuleProto& module); // Reshard the parameters if the expected shardings mismatch. Resharding is // expensive especially for those already sharded. The cost can easily be