diff --git a/.bazelversion b/.bazelversion index 03f488b076a..f22d756da39 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -5.3.0 +6.5.0 diff --git a/.circleci/common.sh b/.circleci/common.sh index c16ab238948..96ec58ff29e 100755 --- a/.circleci/common.sh +++ b/.circleci/common.sh @@ -144,10 +144,13 @@ function run_torch_xla_python_tests() { PJRT_DEVICE=CUDA torchrun --nnodes=1 --node_rank=0 --nproc_per_node=$num_devices test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=16 --num_epochs=1 --num_steps=25 --model=resnet18 # single-host-SPMD - XLA_USE_SPMD=1 PJRT_DEVICE=CUDA torchrun --nnodes=1 --node_rank=0 --nproc_per_node=1 test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 16 --model=resnet50 --sharding=batch --num_epochs=1 --num_steps=25 --model=resnet18 + # TODO: Reduce BS due to GPU test OOM in CI after pin update to 03/05/2024 (#6677) + XLA_USE_SPMD=1 PJRT_DEVICE=CUDA torchrun --nnodes=1 --node_rank=0 --nproc_per_node=1 test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 8 --model=resnet50 --sharding=batch --num_epochs=1 --num_steps=25 --model=resnet18 - PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1 - PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1 + # TODO: Reduce BS due to GPU test OOM in CI after pin update to 03/05/2024 (#6677) + PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1 --batch_size 32 --test_set_batch_size 32 + # TODO: Reduce BS due to GPU test OOM in CI after pin update to 03/05/2024 (#6677) + PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1 --batch_size 32 --test_set_batch_size 32 XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1 # Syncfree SGD optimizer tests if [ -d ./torch_xla/amp/syncfree ]; then diff --git a/WORKSPACE b/WORKSPACE index b55bbf30b79..b30c5f591c7 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -51,9 +51,9 @@ http_archive( "//openxla_patches:f16_abi_clang.diff", "//openxla_patches:quant_dequant_converter.diff", ], - strip_prefix = "xla-419a3d736bdfe6891bc40e9ab5c78b466c8e0dc6", + strip_prefix = "xla-18cbd2019898d3a7b563aeb73683f0c5a6ce14fd", urls = [ - "https://github.com/openxla/xla/archive/419a3d736bdfe6891bc40e9ab5c78b466c8e0dc6.tar.gz", + "https://github.com/openxla/xla/archive/18cbd2019898d3a7b563aeb73683f0c5a6ce14fd.tar.gz", ], ) diff --git a/openxla_patches/quant_dequant_converter.diff b/openxla_patches/quant_dequant_converter.diff index 91aae488afc..adaca99b9b5 100644 --- a/openxla_patches/quant_dequant_converter.diff +++ b/openxla_patches/quant_dequant_converter.diff @@ -1,31 +1,36 @@ -// TODO(lsy323): This is a patch on the HLO->StableHLO converter, this allows the custom call to -// stablehlo.uniform_quantize/dequantize to be converted to stablehlo.uniform_quantize/dequantize. -// The patch can be removed after quantize/dequantize, quantized dtype support is added to HLO. diff --git a/xla/translate/hlo_to_mhlo/BUILD b/xla/translate/hlo_to_mhlo/BUILD -index cd6e427c3..9deecdcb2 100644 +index a75baa9dad..a09fca4898 100644 --- a/xla/translate/hlo_to_mhlo/BUILD +++ b/xla/translate/hlo_to_mhlo/BUILD -@@ -68,6 +68,7 @@ cc_library( - "@llvm-project//mlir:ArithDialect", +@@ -40,6 +40,7 @@ cc_library( + "@llvm-project//llvm:Support", "@llvm-project//mlir:AsmParser", - "@llvm-project//mlir:FuncDialect", -+ "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:IR", - "@llvm-project//mlir:SparseTensorDialect", - "@tsl//tsl/platform:statusor", -diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc -index 5e5e32652..d246383bd 100644 ---- a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc -+++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc -@@ -669,6 +669,71 @@ StatusOr HloFunctionImporter::ImportInstruction( - return importer.ImportInstructionWithLayout(instr, operands, builder, mode); - } - -+Type getQuantizedType(mlir::DictionaryAttr& backend_config) { ++ "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], + ) +diff --git a/xla/translate/hlo_to_mhlo/custom_call_importer.cc b/xla/translate/hlo_to_mhlo/custom_call_importer.cc +index 9250747210..9cba4c992a 100644 +--- a/xla/translate/hlo_to_mhlo/custom_call_importer.cc ++++ b/xla/translate/hlo_to_mhlo/custom_call_importer.cc +@@ -21,6 +21,7 @@ limitations under the License. + #include "absl/strings/match.h" + #include "llvm/ADT/STLExtras.h" + #include "mlir/AsmParser/AsmParser.h" // from @llvm-project ++#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project + #include "mlir/IR/Builders.h" // from @llvm-project + #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project + #include "mlir/IR/Location.h" // from @llvm-project +@@ -93,6 +94,66 @@ absl::StatusOr ImportRealDynamicSliceOp( + + } // namespace + ++mlir::Type getQuantizedType(mlir::DictionaryAttr& backend_config) { + std::vector scales; + std::vector zero_points; + int64_t quantization_dimension = -1, storage_max = 0, storage_min = 0; -+ Type storage_type, expressed_type; ++ mlir::Type storage_type, expressed_type; + + auto scales_attr = backend_config.get("scale"); + if (scales_attr) { @@ -61,15 +66,11 @@ index 5e5e32652..d246383bd 100644 + auto storage_type_attr = backend_config.get("storage_type"); + if (storage_type_attr) { + storage_type = storage_type_attr.cast().getValue(); -+ //.cast() -+ //.getElementType(); + } + + auto expressed_type_attr = backend_config.get("expressed_type"); + if (expressed_type_attr) { + expressed_type = expressed_type_attr.cast().getValue(); -+ //.cast() -+ //.getElementType(); + } + + auto is_signed = storage_type.cast().isSignless(); @@ -85,54 +86,37 @@ index 5e5e32652..d246383bd 100644 + } +} + + absl::StatusOr ImportCustomCallAsOp( + const HloCustomCallInstruction* instruction, mlir::Location loc, + mlir::Type result_type, mlir::ValueRange operands, +@@ -112,6 +173,30 @@ absl::StatusOr ImportCustomCallAsOp( + return ImportRealDynamicSliceOp(backend_config_str, loc, result_type, + operands, builder); + } + - StatusOr HloFunctionImporter::ImportCustomCallAsOp( - const HloInstruction* instruction, mlir::Location loc, - const Type result_type, mlir::ValueRange operands, -@@ -992,6 +1057,25 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( - "Couldn't parse backend config into a dictionary attribute"); - - attributes.push_back(builder_->getNamedAttr("backend_config", attr)); -+ auto backend_config = attr.cast(); -+ if (custom_call->custom_call_target() == -+ "stablehlo.uniform_quantize") { -+ return func_builder -+ ->create( -+ loc, -+ mlir::RankedTensorType::get( -+ result_type.cast().getShape(), -+ getQuantizedType(backend_config)), -+ operands) -+ .getOperation(); -+ } ++ auto attr = mlir::parseAttribute(backend_config_str, builder->getContext()) ++ .dyn_cast(); ++ if (!attr) { ++ return Internal( ++ "Couldn't parse backend config into a dictionary attribute"); ++ } ++ auto backend_config = attr.cast(); ++ if (custom_call_target == "mhlo.uniform_quantize") { ++ return builder ++ ->create( ++ loc, ++ mlir::RankedTensorType::get( ++ result_type.cast().getShape(), ++ getQuantizedType(backend_config)), ++ operands) ++ .getOperation(); ++ } + -+ if (custom_call->custom_call_target() == -+ "stablehlo.uniform_dequantize") { -+ return func_builder -+ ->create( -+ loc, result_type, operands) .getOperation(); -+ } - } - } else { - attributes.push_back(builder_->getNamedAttr( -diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc -index f9edd1272..23a747fb1 100644 ---- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc -+++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc -@@ -19,6 +19,8 @@ limitations under the License. - #include - #include - -+#include "mlir/Dialect/Quant/QuantOps.h" -+#include "mlir/Dialect/Quant/QuantTypes.h" - #include "mlir/IR/Attributes.h" // from @llvm-project - #include "xla/hlo/ir/hlo_computation.h" - #include "xla/hlo/ir/hlo_instruction.h" -@@ -41,6 +43,7 @@ HloModuleImporter::HloModuleImporter(mlir::ModuleOp module, - module.getContext()->loadDialect(); - module.getContext()->loadDialect(); - module.getContext()->loadDialect(); -+ module.getContext()->loadDialect(); ++ if (custom_call_target == "mhlo.uniform_dequantize") { ++ return builder ++ ->create(loc, result_type, operands) ++ .getOperation(); ++ } + return InvalidArgument("Unsupported MHLO op custom_call %s", + custom_call_target); } - - namespace { diff --git a/setup.py b/setup.py index be095744524..4d26e1a2460 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,7 @@ base_dir = os.path.dirname(os.path.abspath(__file__)) -_date = '20240213' +_date = '20240305' _libtpu_version = f'0.1.dev{_date}' _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl' _jax_version = f'0.4.25.dev{_date}' diff --git a/torch_xla/csrc/ops/dequant_tensor.cpp b/torch_xla/csrc/ops/dequant_tensor.cpp index bfafcb34e23..dd074ad2630 100644 --- a/torch_xla/csrc/ops/dequant_tensor.cpp +++ b/torch_xla/csrc/ops/dequant_tensor.cpp @@ -39,7 +39,7 @@ XlaOpVector DequantizeTensor::Lower(LoweringContext* loctx) const { xla::Shape output_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, input_shape.dimensions()); // TODO(lsy323): Lower to HLO directly once qdtype is added to HLO. - static const std::string opname = "stablehlo.uniform_dequantize"; + static const std::string opname = "mhlo.uniform_dequantize"; auto qparams = QuantParams(scale_, zero_point_, quant_min_, quant_max_, axis_, dtype_, xla::PrimitiveType::F32); diff --git a/torch_xla/csrc/ops/quant_tensor.cpp b/torch_xla/csrc/ops/quant_tensor.cpp index dcd1da53596..4845168be34 100644 --- a/torch_xla/csrc/ops/quant_tensor.cpp +++ b/torch_xla/csrc/ops/quant_tensor.cpp @@ -40,7 +40,7 @@ XlaOpVector QuantizeTensor::Lower(LoweringContext* loctx) const { GetTorchIntDtypeToHloDtype(dtype_), input_shape.dimensions()); // TODO(lsy323): Lower to HLO directly once qdtype is added to HLO. - static const std::string opname = "stablehlo.uniform_quantize"; + static const std::string opname = "mhlo.uniform_quantize"; auto qparams = QuantParams(scale_, zero_point_, quant_min_, quant_max_, axis_, dtype_, input_shape.element_type()); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 57db481707d..7fc531d0086 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -56,7 +56,9 @@ std::unordered_map build_index_map( xla::Shape host_output_shape(xla::PjRtBuffer* buffer) { xla::Shape shape = xla::ShapeUtil::MakeShape( buffer->element_type(), buffer->logical_dimensions().value()); - *shape.mutable_layout() = buffer->layout(); + *shape.mutable_layout() = + dynamic_cast(buffer->layout().get()) + ->xla_layout(); return xla::ShapeUtil::DeviceShapeToHostShape(shape); } diff --git a/torch_xla/csrc/runtime/profiler.cc b/torch_xla/csrc/runtime/profiler.cc index c9c66c0e1b9..9700ab7abd6 100644 --- a/torch_xla/csrc/runtime/profiler.cc +++ b/torch_xla/csrc/runtime/profiler.cc @@ -17,10 +17,10 @@ namespace profiler { namespace { const PLUGIN_Profiler_Api* FindProfilerApi(const PJRT_Api* pjrt_api) { - const PJRT_Structure_Base* next = - reinterpret_cast(pjrt_api->extension_start); + const PJRT_Extension_Base* next = + reinterpret_cast(pjrt_api->extension_start); while (next != nullptr && - next->type != PJRT_Structure_Type::PJRT_Structure_Type_Profiler) { + next->type != PJRT_Extension_Type::PJRT_Extension_Type_Profiler) { next = next->next; } if (next == nullptr) {