diff --git a/WORKSPACE b/WORKSPACE index ace663554168..2d3e69033fda 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -43,6 +43,8 @@ http_archive( "//openxla_patches:gpu_race_condition.diff", "//openxla_patches:f16_abi_clang.diff", "//openxla_patches:gpu_topk_rewriter.diff", + "//openxla_patches:quant_dequant_converter.diff", + "//openxla_patches:stablehlo_quant_seralization.diff", ], strip_prefix = "xla-4f8381651977dff16b1d86bb4b198eb733c5f478", urls = [ diff --git a/openxla_patches/quant_dequant_converter.diff b/openxla_patches/quant_dequant_converter.diff new file mode 100644 index 000000000000..d35e36a5e22b --- /dev/null +++ b/openxla_patches/quant_dequant_converter.diff @@ -0,0 +1,137 @@ +// TODO(lsy323): This is a patch on the HLO->StableHLO converter, this allows the custom call to +// stablehlo.uniform_quantize/dequantize to be converted to stablehlo.uniform_quantize/dequantize. +// The patch can be removed after quantize/dequantize, quantized dtype support is added to HLO. +diff --git a/xla/translate/hlo_to_mhlo/BUILD b/xla/translate/hlo_to_mhlo/BUILD +index f74973ae1..8e3f0e06b 100644 +--- a/xla/translate/hlo_to_mhlo/BUILD ++++ b/xla/translate/hlo_to_mhlo/BUILD +@@ -67,6 +67,7 @@ cc_library( + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:FuncDialect", ++ "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:SparseTensorDialect", + "@tsl//tsl/platform:statusor", +diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc +index 08d5f49c8..2f9ad1e0b 100644 +--- a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc ++++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc +@@ -664,6 +664,70 @@ StatusOr HloFunctionImporter::ImportInstruction( + return importer.ImportInstructionWithLayout(instr, operands, builder, mode); + } + ++Type getQuantizedType(mlir::DictionaryAttr& backend_config) { ++ std::vector scales; ++ std::vector zero_points; ++ int64_t quantization_dimension = -1, storage_max = 0, storage_min = 0; ++ Type storage_type, expressed_type; ++ ++ auto scales_attr = backend_config.get("scale"); ++ if (scales_attr) { ++ for (auto scale_attr : scales_attr.cast()) { ++ scales.push_back(scale_attr.cast().getValueAsDouble()); ++ } ++ } ++ ++ auto zero_points_attr = backend_config.get("zero_point"); ++ if (zero_points_attr) { ++ for (auto zero_point_attr : zero_points_attr.cast()) { ++ zero_points.push_back(zero_point_attr.cast().getInt()); ++ } ++ } ++ ++ auto quantization_dimension_attr = ++ backend_config.get("quantization_dimension"); ++ if (quantization_dimension_attr) { ++ quantization_dimension = ++ quantization_dimension_attr.cast().getInt(); ++ } ++ ++ auto storage_max_attr = backend_config.get("storage_max"); ++ if (storage_max_attr) { ++ storage_max = storage_max_attr.cast().getInt(); ++ } ++ ++ auto storage_min_attr = backend_config.get("storage_min"); ++ if (storage_min_attr) { ++ storage_min = storage_min_attr.cast().getInt(); ++ } ++ ++ auto storage_type_attr = backend_config.get("storage_type"); ++ if (storage_type_attr) { ++ storage_type = storage_type_attr.cast().getValue(); ++ //.cast() ++ //.getElementType(); ++ } ++ ++ auto expressed_type_attr = backend_config.get("expressed_type"); ++ if (expressed_type_attr) { ++ expressed_type = expressed_type_attr.cast().getValue(); ++ //.cast() ++ //.getElementType(); ++ } ++ ++ auto is_signed = storage_type.cast().isSigned(); ++ ++ if (quantization_dimension != -1) { ++ return mlir::quant::UniformQuantizedPerAxisType::get( ++ is_signed, storage_type, expressed_type, scales, zero_points, ++ quantization_dimension, storage_min, storage_max); ++ } else { ++ return mlir::quant::UniformQuantizedType::get( ++ is_signed, storage_type, expressed_type, scales[0], zero_points[0], ++ storage_min, storage_max); ++ } ++} ++ + StatusOr HloFunctionImporter::ImportInstructionImpl( + const HloInstruction* instruction, + const llvm::SmallVectorImpl& operands, +@@ -933,6 +997,25 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( + "Couldn't parse backend config into a dictionary attribute"); + + attributes.push_back(builder_->getNamedAttr("backend_config", attr)); ++ auto backend_config = attr.cast(); ++ if (custom_call->custom_call_target() == ++ "stablehlo.uniform_quantize") { ++ return func_builder ++ ->create( ++ loc, ++ mlir::RankedTensorType::get( ++ result_type.cast().getShape(), ++ getQuantizedType(backend_config)), ++ operands) ++ .getOperation(); ++ } ++ ++ if (custom_call->custom_call_target() == ++ "stablehlo.uniform_dequantize") { ++ return func_builder ++ ->create( ++ loc, result_type, operands) .getOperation(); ++ } + } + } else { + attributes.push_back(builder_->getNamedAttr( +diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc +index 9f05992c8..03cf4840d 100644 +--- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc ++++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc +@@ -19,6 +19,8 @@ limitations under the License. + #include + #include + ++#include "mlir/Dialect/Quant/QuantOps.h" ++#include "mlir/Dialect/Quant/QuantTypes.h" + #include "mlir/IR/Attributes.h" // from @llvm-project + #include "xla/hlo/ir/hlo_computation.h" + #include "xla/hlo/ir/hlo_instruction.h" +@@ -41,6 +43,7 @@ HloModuleImporter::HloModuleImporter(mlir::ModuleOp module, + module.getContext()->loadDialect(); + module.getContext()->loadDialect(); + module.getContext()->loadDialect(); ++ module.getContext()->loadDialect(); + } + + namespace { diff --git a/openxla_patches/stablehlo_quant_seralization.diff b/openxla_patches/stablehlo_quant_seralization.diff new file mode 100644 index 000000000000..fc4328dcfa79 --- /dev/null +++ b/openxla_patches/stablehlo_quant_seralization.diff @@ -0,0 +1,45 @@ +// TODO(lsy323): This patch is needed to serialize stablehlo.uniform_quantize/dequantize in bytecode format +// This patch can be removed after https://github.com/openxla/stablehlo/issues/1812 is fixed. +diff --git a/third_party/stablehlo/stablehlo_quant_seralization.patch b/third_party/stablehlo/stablehlo_quant_seralization.patch +new file mode 100644 +index 000000000..24e23b67d +--- /dev/null ++++ b/third_party/stablehlo/stablehlo_quant_seralization.patch +@@ -0,0 +1,26 @@ ++diff --git a/stablehlo/api/PortableApi.cpp b/stablehlo/api/PortableApi.cpp ++index 07c856db..cd169cae 100644 ++--- a/stablehlo/api/PortableApi.cpp +++++ b/stablehlo/api/PortableApi.cpp ++@@ -15,10 +15,13 @@ limitations under the License. ++ ++ #include "stablehlo/api/PortableApi.h" ++ +++#include ++ #include ++ ++ #include "mlir/Bytecode/BytecodeWriter.h" ++ #include "mlir/Dialect/Func/IR/FuncOps.h" +++#include "mlir/Dialect/Quant/QuantOps.h" +++#include "mlir/Dialect/Quant/QuantTypes.h" ++ #include "mlir/IR/MLIRContext.h" ++ #include "mlir/Parser/Parser.h" ++ #include "stablehlo/dialect/Serialization.h" ++@@ -33,6 +36,7 @@ void loadSerializationDialects(MLIRContext* context) { ++ context->loadDialect(); ++ context->loadDialect(); ++ context->loadDialect(); +++ context->loadDialect(); ++ } ++ } // namespace ++ +diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl +index 9f4494aac..64fa072bb 100644 +--- a/third_party/stablehlo/workspace.bzl ++++ b/third_party/stablehlo/workspace.bzl +@@ -15,5 +15,6 @@ def repo(): + urls = tf_mirror_urls("https://github.com/openxla/stablehlo/archive/{commit}.zip".format(commit = STABLEHLO_COMMIT)), + patch_file = [ + "//third_party/stablehlo:temporary.patch", # Autogenerated, don't remove. ++ "//third_party/stablehlo:stablehlo_quant_seralization.patch", # Load quant dialect. + ], + ) diff --git a/test/stablehlo/test_pt2e_qdq.py b/test/stablehlo/test_pt2e_qdq.py new file mode 100644 index 000000000000..39e0442e2f46 --- /dev/null +++ b/test/stablehlo/test_pt2e_qdq.py @@ -0,0 +1,117 @@ +import os +import unittest +from typing import Callable, Dict, List + +import torch +import torch_xla.core.xla_model as xm +import torchvision +from torch._export import capture_pre_autograd_graph +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.ao.quantization.quantizer.xnnpack_quantizer import ( + XNNPACKQuantizer, get_symmetric_quantization_config) +from torch_xla import stablehlo +from torch_xla.tf_saved_model_integration import \ + save_torch_module_as_tf_saved_model + +# Needed to workaround the stablehlo bytecode serialization issue in https://github.com/openxla/stablehlo/issues/1812 +os.environ['STABLEHLO_BYTECODE_FROM_PRETTYPRINT'] = '1' + +_TORCH_QUANTIZE_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, +] + +_TORCH_DEQUANTIZE_OPS = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, +] + + +def count_fx_graph_nodes(g: torch.fx.Graph, op_dict: Dict[str, List[Callable]]): + cnt = {} + for k in op_dict.keys(): + cnt[k] = 0 + for node in g.nodes: + for k, ops in op_dict.items(): + if node.target in ops: + cnt[k] += 1 + return cnt + + +def count_qdq_ops(g: torch.fx.Graph): + op_dict = { + "qunatize": _TORCH_QUANTIZE_OPS, + "dequantize": _TORCH_DEQUANTIZE_OPS, + } + return count_fx_graph_nodes(g, op_dict) + + +class PT2EExportTest(unittest.TestCase): + + def test_per_tensor_qdq(self): + device = xm.xla_device() + x = torch.randn(2, 3, 4, 5).to(device) + x = torch.ops.quantized_decomposed.quantize_per_tensor( + x, 0.4, 2, -128, 127, torch.int8) + x = torch.ops.quantized_decomposed.dequantize_per_tensor( + x, 0.4, 2, -128, 127, torch.int8) + stablehlo_txt = xm.get_stablehlo([x]) + self.assertEqual( + stablehlo_txt.count( + 'tensor<2x3x4x5x!quant.uniform>'), 2) + self.assertEqual(stablehlo_txt.count("stablehlo.uniform_quantize"), 1) + self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1) + + def test_per_channel_qdq(self): + device = xm.xla_device() + x = torch.randn(2, 3, 4, 5).to(device) + scale = torch.tensor([3.2, 5.3, 0.1, 10]) + zero_point = torch.tensor([1, 2, -1, -2], dtype=torch.int8) + x = torch.ops.quantized_decomposed.quantize_per_channel( + x, scale, zero_point, 2, -128, 127, torch.int8) + x = torch.ops.quantized_decomposed.dequantize_per_channel( + x, scale, zero_point, 2, -128, 127, torch.int8) + stablehlo_txt = xm.get_stablehlo([x]) + self.assertEqual( + stablehlo_txt.count( + 'tensor<2x3x4x5x!quant.uniform>' + ), 2) + self.assertEqual(stablehlo_txt.count("stablehlo.uniform_quantize"), 1) + self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1) + + def test_resnet18(self): + # Step 1: export resnet18 + args = (torch.randn(1, 3, 224, 224),) + m = torchvision.models.resnet18().eval() + m = capture_pre_autograd_graph(m, args) + + # Step 2: Insert observers or fake quantize modules + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config()) + m = prepare_pt2e(m, quantizer) + + # Step 3: Quantize the model + m = convert_pt2e(m) + + # Trace with torch/xla and export stablehlo + exported = torch.export.export(m, args) + stablehlo_gm = stablehlo.exported_program_to_stablehlo(exported) + stablehlo_txt = stablehlo_gm.get_stablehlo_text() + fx_node_cnt = count_qdq_ops(exported.graph_module.graph) + self.assertEqual( + stablehlo_txt.count("stablehlo.uniform_quantize"), + fx_node_cnt["qunatize"]) + self.assertEqual( + stablehlo_txt.count("stablehlo.uniform_dequantize"), + fx_node_cnt["dequantize"]) + # Save as tf.saved_model + save_path = '/tmp/tf_saved_model/tmp1' + save_torch_module_as_tf_saved_model(m, args, save_path) + self.assertTrue(os.path.exists(os.path.join(save_path, 'saved_model.pb'))) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index b18014ab2dfb..0375a2dbc607 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -49,6 +49,7 @@ ptxla_cc_library( "nll_loss.cpp", "nms_op.cpp", "pooling.cpp", + "quant_util.cpp", "random.cpp", "reduction.cpp", "resize_ops.cpp", @@ -88,6 +89,7 @@ ptxla_cc_library( "nll_loss.h", "nms_op.h", "pooling.h", + "quant_util.h", "random.h", "reduction.h", "resize_ops.h", diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 4e0b524761a5..7efe73a43af9 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -199,6 +199,28 @@ at::Tensor AllReduce(const std::string& reduce_type, const at::Tensor& input, return bridge::AtenFromXlaTensor(std::move(result)); } +at::Tensor QuantizeTensor(const at::Tensor& input, + const std::vector& scale_list, + const std::vector& zero_point_list, + int quant_min, int quant_max, + const std::string& dtype, int axis) { + auto result = tensor_methods::quantize_tensor( + bridge::GetXlaTensor(input), scale_list, zero_point_list, quant_min, + quant_max, dtype, axis); + return bridge::AtenFromXlaTensor(std::move(result)); +} + +at::Tensor DequantizeTensor(const at::Tensor& input, + const std::vector& scale_list, + const std::vector& zero_point_list, + int quant_min, int quant_max, + const std::string& dtype, int axis) { + auto result = tensor_methods::dequantize_tensor( + bridge::GetXlaTensor(input), scale_list, zero_point_list, quant_min, + quant_max, dtype, axis); + return bridge::AtenFromXlaTensor(std::move(result)); +} + std::pair> ReduceScatter( const std::string& reduce_type, const at::Tensor& input, const std::shared_ptr& token, double scale, @@ -1112,6 +1134,30 @@ void InitXlaModuleBindings(py::module m) { } return result; }); + m.def("_xla_quantize_tensor", + [](const at::Tensor& input, const std::vector& scale_list, + const std::vector& zero_point_list, int quant_min, + int quant_max, const std::string& dtype, int axis) -> at::Tensor { + at::Tensor result; + { + NoGilSection nogil; + result = QuantizeTensor(input, scale_list, zero_point_list, + quant_min, quant_max, dtype, axis); + } + return result; + }); + m.def("_xla_dequantize_tensor", + [](const at::Tensor& input, const std::vector& scale_list, + const std::vector& zero_point_list, int quant_min, + int quant_max, const std::string& dtype, int axis) -> at::Tensor { + at::Tensor result; + { + NoGilSection nogil; + result = DequantizeTensor(input, scale_list, zero_point_list, + quant_min, quant_max, dtype, axis); + } + return result; + }); m.def("_xla_all_to_all", [](const at::Tensor& input, const std::shared_ptr& token, diff --git a/torch_xla/csrc/ir_dump_util.cpp b/torch_xla/csrc/ir_dump_util.cpp index a1998c2a2b88..448cbf63d276 100644 --- a/torch_xla/csrc/ir_dump_util.cpp +++ b/torch_xla/csrc/ir_dump_util.cpp @@ -288,11 +288,11 @@ std::string DumpUtil::ToHlo(c10::ArrayRef values, case EmitMode::kHloReadable: return ConsumeValue(runtime::util::GetComputationHloText(computation)); case EmitMode::kStableHloReadable: - return runtime::hloToStablehlo(&computation.proto(), - /* emit_bytecode = */ false); + return hloToStablehlo(&computation.proto(), + /* emit_bytecode = */ false); case EmitMode::kStableHloBytecode: - return runtime::hloToStablehlo(&computation.proto(), - /* emit_bytecode = */ true); + return hloToStablehlo(&computation.proto(), + /* emit_bytecode = */ true); } } diff --git a/torch_xla/csrc/ops/dequant_tensor.cpp b/torch_xla/csrc/ops/dequant_tensor.cpp new file mode 100644 index 000000000000..e6f11e7bdef3 --- /dev/null +++ b/torch_xla/csrc/ops/dequant_tensor.cpp @@ -0,0 +1,61 @@ +#include "torch_xla/csrc/ops/dequant_tensor.h" + +#include + +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/quant_util.h" +#include "torch_xla/csrc/shape_helper.h" + +namespace torch_xla { + +DequantizeTensor::DequantizeTensor(const torch::lazy::Value& input, + const std::vector& scale, + const std::vector& zero_point, + int quant_min, int quant_max, + const std::string& dtype, int axis) + : XlaNode( + xla_dequantize_tensor, {input}, + GetXlaShape(input) /* fix when quant type is added to HLO */, + /*num_outputs=*/1, + torch::lazy::MHash(scale, zero_point, quant_min, quant_max, dtype)), + quant_min_(quant_min), + quant_max_(quant_max), + axis_(axis), + dtype_(dtype), + scale_(scale), + zero_point_(zero_point) {} + +torch::lazy::NodePtr DequantizeTensor::Clone( + torch::lazy::OpList operands) const { + return torch::lazy::MakeNode(operands.at(0), scale_, + zero_point_, quant_min_, + quant_max_, dtype_, axis_); +} + +XlaOpVector DequantizeTensor::Lower(LoweringContext* loctx) const { + xla::XlaOp input = loctx->GetOutputOp(operand(0)); + xla::Shape input_shape = ShapeHelper::ShapeOfXlaOp(input); + // TODO(lsy323): Lower to HLO directly once qdtype is added to HLO. + static const std::string opname = "stablehlo.uniform_dequantize"; + auto qparams = QuantParams(scale_, zero_point_, quant_min_, quant_max_, axis_, + dtype_, input_shape.element_type()); + + xla::XlaOp output = xla::CustomCall( + input.builder(), opname, {input}, input_shape, + qparams.SerializeToAttrDictStr(), + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/xla::CustomCallApiVersion::API_VERSION_TYPED_FFI); + return ReturnOp(output, loctx); +} + +std::string DequantizeTensor::ToString() const { + std::stringstream ss; + ss << XlaNode::ToString() << ", quant_min=" << quant_min_ + << ", quant_max=" << quant_max_ << ", dtype=" << dtype_; + return ss.str(); +} + +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/dequant_tensor.h b/torch_xla/csrc/ops/dequant_tensor.h new file mode 100644 index 000000000000..6b23107e1e44 --- /dev/null +++ b/torch_xla/csrc/ops/dequant_tensor.h @@ -0,0 +1,32 @@ +#ifndef XLA_TORCH_XLA_CSRC_OPS_DEQUANT_TENSOR_H_ +#define XLA_TORCH_XLA_CSRC_OPS_DEQUANT_TENSOR_H_ + +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { + +class DequantizeTensor : public XlaNode { + public: + DequantizeTensor(const torch::lazy::Value& input, + const std::vector& scale, + const std::vector& zero_point, int quant_min, + int quant_max, const std::string& dtype, int axis); + + std::string ToString() const override; + + torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + private: + int quant_min_; + int quant_max_; + int axis_; + std::string dtype_; + std::vector scale_; + std::vector zero_point_; +}; + +} // namespace torch_xla + +#endif // XLA_TORCH_XLA_CSRC_OPS_QUANT_TENSOR_H_ diff --git a/torch_xla/csrc/ops/quant_tensor.cpp b/torch_xla/csrc/ops/quant_tensor.cpp new file mode 100644 index 000000000000..26e975790e5e --- /dev/null +++ b/torch_xla/csrc/ops/quant_tensor.cpp @@ -0,0 +1,60 @@ +#include "torch_xla/csrc/ops/quant_tensor.h" + +#include + +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/quant_util.h" +#include "torch_xla/csrc/shape_helper.h" + +namespace torch_xla { + +QuantizeTensor::QuantizeTensor(const torch::lazy::Value& input, + const std::vector& scale, + const std::vector& zero_point, + int quant_min, int quant_max, + const std::string& dtype, int axis) + : XlaNode( + xla_quantize_tensor, {input}, + GetXlaShape(input) /* fix when quant type is added to HLO */, + /*num_outputs=*/1, + torch::lazy::MHash(scale, zero_point, quant_min, quant_max, dtype)), + quant_min_(quant_min), + quant_max_(quant_max), + axis_(axis), + dtype_(dtype), + scale_(scale), + zero_point_(zero_point) {} + +torch::lazy::NodePtr QuantizeTensor::Clone(torch::lazy::OpList operands) const { + return torch::lazy::MakeNode(operands.at(0), scale_, + zero_point_, quant_min_, + quant_max_, dtype_, axis_); +} + +XlaOpVector QuantizeTensor::Lower(LoweringContext* loctx) const { + xla::XlaOp input = loctx->GetOutputOp(operand(0)); + xla::Shape input_shape = ShapeHelper::ShapeOfXlaOp(input); + // TODO(lsy323): Lower to HLO directly once qdtype is added to HLO. + static const std::string opname = "stablehlo.uniform_quantize"; + auto qparams = QuantParams(scale_, zero_point_, quant_min_, quant_max_, axis_, + dtype_, input_shape.element_type()); + + xla::XlaOp output = xla::CustomCall( + input.builder(), opname, {input}, input_shape, + qparams.SerializeToAttrDictStr(), + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/xla::CustomCallApiVersion::API_VERSION_TYPED_FFI); + return ReturnOp(output, loctx); +} + +std::string QuantizeTensor::ToString() const { + std::stringstream ss; + ss << XlaNode::ToString() << ", quant_min=" << quant_min_ + << ", quant_max=" << quant_max_ << ", dtype=" << dtype_; + return ss.str(); +} + +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/quant_tensor.h b/torch_xla/csrc/ops/quant_tensor.h new file mode 100644 index 000000000000..2ae5c7caa2b7 --- /dev/null +++ b/torch_xla/csrc/ops/quant_tensor.h @@ -0,0 +1,32 @@ +#ifndef XLA_TORCH_XLA_CSRC_OPS_QUANT_TENSOR_H_ +#define XLA_TORCH_XLA_CSRC_OPS_QUANT_TENSOR_H_ + +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { + +class QuantizeTensor : public XlaNode { + public: + QuantizeTensor(const torch::lazy::Value& input, + const std::vector& scale, + const std::vector& zero_point, int quant_min, + int quant_max, const std::string& dtype, int axis); + + std::string ToString() const override; + + torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + private: + int quant_min_; + int quant_max_; + int axis_; + std::string dtype_; + std::vector scale_; + std::vector zero_point_; +}; + +} // namespace torch_xla + +#endif // XLA_TORCH_XLA_CSRC_OPS_QUANT_TENSOR_H_ diff --git a/torch_xla/csrc/ops/xla_ops.cpp b/torch_xla/csrc/ops/xla_ops.cpp index fa106e5849ca..a9a7f9d62fcc 100644 --- a/torch_xla/csrc/ops/xla_ops.cpp +++ b/torch_xla/csrc/ops/xla_ops.cpp @@ -10,6 +10,7 @@ const OpKindWrapper xla_cast("xla::cast"); const OpKindWrapper xla_collective_permute("xla::collective_permute"); const OpKindWrapper xla_cross_replica_sum("xla::cross_replica_sum"); const OpKindWrapper xla_device_data("xla::device_data"); +const OpKindWrapper xla_dequantize_tensor("xla::dequantize_tensor"); const OpKindWrapper xla_diagonal_view_update("xla::diagonal_view_update"); const OpKindWrapper xla_einsum_backward("xla::einsum_backward"); const OpKindWrapper xla_generic_slice("xla::generic_slice"); @@ -18,6 +19,7 @@ const OpKindWrapper xla_moving_average("xla::moving_average"); const OpKindWrapper xla_nms("xla::nms"); const OpKindWrapper xla_not_supported("xla::not_supported"); const OpKindWrapper xla_optimization_barrier("xla::optimization_barrier"); +const OpKindWrapper xla_quantize_tensor("xla::quantize_tensor"); const OpKindWrapper xla_recv("xla::recv"); const OpKindWrapper xla_reduce_scatter("xla::reduce_scatter"); const OpKindWrapper xla_replication_pad("xla::replication_pad"); diff --git a/torch_xla/csrc/ops/xla_ops.h b/torch_xla/csrc/ops/xla_ops.h index fa8082b978ac..f39227dd6dd0 100644 --- a/torch_xla/csrc/ops/xla_ops.h +++ b/torch_xla/csrc/ops/xla_ops.h @@ -36,6 +36,7 @@ extern const OpKindWrapper xla_cast; extern const OpKindWrapper xla_collective_permute; extern const OpKindWrapper xla_cross_replica_sum; extern const OpKindWrapper xla_device_data; +extern const OpKindWrapper xla_dequantize_tensor; extern const OpKindWrapper xla_diagonal_view_update; extern const OpKindWrapper xla_einsum_backward; extern const OpKindWrapper xla_generic_slice; @@ -44,6 +45,7 @@ extern const OpKindWrapper xla_moving_average; extern const OpKindWrapper xla_nms; extern const OpKindWrapper xla_not_supported; extern const OpKindWrapper xla_optimization_barrier; +extern const OpKindWrapper xla_quantize_tensor; extern const OpKindWrapper xla_recv; extern const OpKindWrapper xla_reduce_scatter; extern const OpKindWrapper xla_replication_pad; diff --git a/torch_xla/csrc/quant_util.cpp b/torch_xla/csrc/quant_util.cpp new file mode 100644 index 000000000000..e836d286cf73 --- /dev/null +++ b/torch_xla/csrc/quant_util.cpp @@ -0,0 +1,56 @@ +#include "torch_xla/csrc/quant_util.h" + +#include +#include +#include + +#include "torch_xla/csrc/runtime/stablehlo_helper.h" + +namespace torch_xla { + +static inline std::string MaybeAppendDecimalForInteger(float v) { + std::stringstream ss; + if (static_cast(v) == v) { + ss << std::fixed << std::setprecision(2); + } + ss << v; + return ss.str(); +} + +template +static std::string SeralizeFloatVector(const std::vector& v, + bool append_decimal = false) { + std::stringstream ss; + ss << '['; + for (size_t i = 0; i < v.size(); ++i) { + if (i != 0) { + ss << ','; + } + if (append_decimal) { + ss << MaybeAppendDecimalForInteger(v.at(i)); + } else { + ss << v.at(i); + } + } + ss << ']'; + return ss.str(); +} + +std::string QuantParams::SerializeToAttrDictStr() const { + std::stringstream ss; + ss << "{"; + if (axis != -1) { + ss << "quantization_dimension=" << axis << ','; + } + ss << "scale=" << SeralizeFloatVector(scale, true) << ','; + ss << "zero_point=" << SeralizeFloatVector(zero_point) << ','; + ss << "storage_type=" << GetTorchDtypeToStablehloDtypeMap().at(dtype) << ','; + ss << "expressed_type=" << GetHloDtypeToStablehloDtypeMap().at(expressed_type) + << ','; + ss << "storage_min=" << quant_min << ','; + ss << "storage_max=" << quant_max; + ss << '}'; + return ss.str(); +} + +} // namespace torch_xla diff --git a/torch_xla/csrc/quant_util.h b/torch_xla/csrc/quant_util.h new file mode 100644 index 000000000000..dbd67ea9ac80 --- /dev/null +++ b/torch_xla/csrc/quant_util.h @@ -0,0 +1,39 @@ +#ifndef XLA_TORCH_XLA_CSRC_QUANT_UTIL_H_ +#define XLA_TORCH_XLA_CSRC_QUANT_UTIL_H_ + +#include +#include +#include + +#include "xla/primitive_util.h" + +namespace torch_xla { + +// Struct for quantization parameters, for per-tensor/channel quant/dequant ops. +struct QuantParams { + std::vector scale; + std::vector zero_point; + int quant_min; + int quant_max; + int axis; + std::string dtype; + xla::PrimitiveType expressed_type; + + QuantParams(const std::vector& scale, + const std::vector& zero_point, int quant_min, int quant_max, + int axis, std::string dtype, xla::PrimitiveType expressed_type) + : scale(scale), + zero_point(zero_point), + quant_min(quant_min), + quant_max(quant_max), + axis(axis), + dtype(dtype), + expressed_type(expressed_type) {} + + // TODO(lsy323): Remove when qdtype is added in XLA. + std::string SerializeToAttrDictStr() const; +}; + +} // namespace torch_xla + +#endif // XLA_TORCH_XLA_CSRC_QUANT_UTIL_H_ diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index be201333c322..fa7e3578729a 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -246,6 +246,7 @@ cc_library( deps = [ ":types", ":xla_util", + "@stablehlo//:stablehlo_portable_api", "@stablehlo//:stablehlo_serialization", "@xla//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "@xla//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 57fbb3cc8613..f9e46dce55de 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -542,8 +542,7 @@ std::vector PjRtComputationClient::Compile( mlir::MLIRContext context; mlir::ModuleOp mlir_module = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); - torch_xla::runtime::ConvertHloToStableHlo( - instance.computation.mutable_proto(), &mlir_module); + ConvertHloToStableHlo(instance.computation.mutable_proto(), &mlir_module); executable = ConsumeValue(client_->Compile(mlir_module, compile_options)); StableHloCompileCounter()->AddValue(1); } else { diff --git a/torch_xla/csrc/runtime/stablehlo_helper.cc b/torch_xla/csrc/runtime/stablehlo_helper.cc index c8593e3be3cf..735e6cb474ff 100644 --- a/torch_xla/csrc/runtime/stablehlo_helper.cc +++ b/torch_xla/csrc/runtime/stablehlo_helper.cc @@ -5,6 +5,7 @@ #include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Transforms/Passes.h" +#include "stablehlo/api/PortableApi.h" // from @stablehlo #include "stablehlo/dialect/Serialization.h" // from @stablehlo #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "stablehlo/dialect/VhloOps.h" // from @stablehlo @@ -16,7 +17,6 @@ #include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" namespace torch_xla { -namespace runtime { static std::string getHloModuleStr(const xla::HloModuleProto* proto) { auto hlo_module = torch_xla::runtime::util::CreateModuleFromProto(*proto); @@ -37,13 +37,23 @@ static std::string getMlirModuleStr(mlir::ModuleOp& mlir_module) { return txt_mlir_module; } -static std::string getMlirModuleBytecode(const mlir::ModuleOp& mlir_module) { +static std::string getMlirModuleBytecode(mlir::ModuleOp& mlir_module) { + static bool from_pretty_print = runtime::sys_util::GetEnvBool( + "STABLEHLO_BYTECODE_FROM_PRETTYPRINT", false); std::string txt_mlir_module; llvm::raw_string_ostream os{txt_mlir_module}; // TODO(lsiyuan): get the highest StableHLO version from runtime. - auto result = mlir::stablehlo::serializePortableArtifact( - mlir_module, /* target_version = */ "0.14.1", os); - XLA_CHECK(result.succeeded()) << "Serializing StableHLO Failed"; + const std::string stablehlo_version = "0.14.23"; + if (!from_pretty_print) { + auto result = mlir::stablehlo::serializePortableArtifact( + mlir_module, /* target_version = */ stablehlo_version, os); + XLA_CHECK(result.succeeded()) << "Serializing StableHLO Failed"; + } else { + std::string pretty_print_txt = getMlirModuleStr(mlir_module); + auto result = mlir::stablehlo::serializePortableArtifact( + pretty_print_txt, /* target_version = */ stablehlo_version, os); + XLA_CHECK(result.succeeded()) << "Serializing StableHLO Failed"; + } return txt_mlir_module; } @@ -151,5 +161,28 @@ void ConvertStableHloToHlo(mlir::ModuleOp* mlir_module, << getMlirModuleStr(*mlir_module); } -} // namespace runtime +const std::unordered_map& +GetTorchDtypeToStablehloDtypeMap() { + static const std::unordered_map m_{ + {"torch.int8", "si8"}, + {"torch.uint8", "ui8"}, + {"torch.int16", "si16"}, + }; + return m_; +} + +const std::unordered_map& +GetHloDtypeToStablehloDtypeMap() { + static const std::unordered_map m_{ + {xla::PrimitiveType::S4, "si4"}, {xla::PrimitiveType::S8, "si8"}, + {xla::PrimitiveType::S16, "si16"}, {xla::PrimitiveType::S32, "si32"}, + {xla::PrimitiveType::S64, "si64"}, {xla::PrimitiveType::U4, "ui4"}, + {xla::PrimitiveType::U8, "ui8"}, {xla::PrimitiveType::U16, "ui16"}, + {xla::PrimitiveType::U32, "ui32"}, {xla::PrimitiveType::U64, "ui64"}, + {xla::PrimitiveType::F16, "f16"}, {xla::PrimitiveType::BF16, "bf16"}, + {xla::PrimitiveType::F32, "f32"}, + }; + return m_; +} + } // namespace torch_xla diff --git a/torch_xla/csrc/runtime/stablehlo_helper.h b/torch_xla/csrc/runtime/stablehlo_helper.h index 359fce4ad62b..232487267ddf 100644 --- a/torch_xla/csrc/runtime/stablehlo_helper.h +++ b/torch_xla/csrc/runtime/stablehlo_helper.h @@ -9,7 +9,6 @@ class MLIRContext; } // namespace mlir namespace torch_xla { -namespace runtime { std::string hloToStablehlo(const xla::HloModuleProto* proto, bool emit_bytecode); @@ -23,7 +22,12 @@ void ConvertStableHloToHlo(mlir::ModuleOp* mlir_module, std::string GetHloModuleStr(const xla::HloModuleProto* proto); -} // namespace runtime +const std::unordered_map& +GetTorchDtypeToStablehloDtypeMap(); + +const std::unordered_map& +GetHloDtypeToStablehloDtypeMap(); + } // namespace torch_xla #endif diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 20890d0be379..30bddfa47adc 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -41,6 +41,7 @@ #include "torch_xla/csrc/ops/cumprod.h" #include "torch_xla/csrc/ops/cumsum.h" #include "torch_xla/csrc/ops/custom_sharding.h" +#include "torch_xla/csrc/ops/dequant_tensor.h" #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/ops/diagonal.h" #include "torch_xla/csrc/ops/discrete_uniform.h" @@ -91,6 +92,7 @@ #include "torch_xla/csrc/ops/prod.h" #include "torch_xla/csrc/ops/put.h" #include "torch_xla/csrc/ops/qr.h" +#include "torch_xla/csrc/ops/quant_tensor.h" #include "torch_xla/csrc/ops/recv.h" #include "torch_xla/csrc/ops/reduce_scatter.h" #include "torch_xla/csrc/ops/reflection_pad2d.h" @@ -2107,6 +2109,28 @@ std::tuple qr(const XLATensorPtr& input, input->CreateFrom(torch::lazy::Value(node, 1))); } +XLATensorPtr quantize_tensor(const XLATensorPtr& input, + const std::vector& scale_list, + const std::vector& zero_point_list, + int quant_min, int quant_max, + const std::string& dtype, int axis) { + torch::lazy::NodePtr node = torch::lazy::MakeNode( + input->GetIrValue(), scale_list, zero_point_list, quant_min, quant_max, + dtype, axis); + return input->CreateFrom(torch::lazy::Value(node)); +} + +XLATensorPtr dequantize_tensor(const XLATensorPtr& input, + const std::vector& scale_list, + const std::vector& zero_point_list, + int quant_min, int quant_max, + const std::string& dtype, int axis) { + torch::lazy::NodePtr node = torch::lazy::MakeNode( + input->GetIrValue(), scale_list, zero_point_list, quant_min, quant_max, + dtype, axis); + return input->CreateFrom(torch::lazy::Value(node)); +} + void random_(XLATensorPtr& input, int64_t from, int64_t to) { XLA_CHECK_LE(from, to); auto input_shape = input->shape(); diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 5a714170300e..e63d267e206c 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -85,6 +85,21 @@ std::vector user_computation( const std::string& opname, absl::Span inputs, runtime::ComputationClient::ComputationPtr computation); +////////////////////////////////////////////////////////////////////////////// +// Quantization related ops here. +////////////////////////////////////////////////////////////////////////////// +XLATensorPtr quantize_tensor(const XLATensorPtr& input, + const std::vector& scale_list, + const std::vector& zero_point_list, + int quant_min, int quant_max, + const std::string& dtype, int axis); + +XLATensorPtr dequantize_tensor(const XLATensorPtr& input, + const std::vector& scale_list, + const std::vector& zero_point_list, + int quant_min, int quant_max, + const std::string& dtype, int axis); + ////////////////////////////////////////////////////////////////////////////// // ATEN operators follows here, listed in alphabetical order. ////////////////////////////////////////////////////////////////////////////// diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 0033176a172e..b02ce64fd9c2 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -773,7 +773,7 @@ std::vector XLAGraphExecutor::ExecuteStablehlo( mlir::stablehlo::deserializePortableArtifact(bytecode, &context); mlir::ModuleOp mlir_module = *module; xla::HloProto hlo_proto; - runtime::ConvertStableHloToHlo(&mlir_module, &context, &hlo_proto); + ConvertStableHloToHlo(&mlir_module, &context, &hlo_proto); xla::HloModuleProto* hlo_module_proto = hlo_proto.mutable_hlo_module(); xla::XlaComputation computation(*hlo_module_proto); diff --git a/torch_xla/experimental/quantized.py b/torch_xla/experimental/quantized.py new file mode 100644 index 000000000000..2ab48ebf13ed --- /dev/null +++ b/torch_xla/experimental/quantized.py @@ -0,0 +1,100 @@ +import numpy as np +import torch +import torch_xla +from torch.library import Library, impl + +quantized_decomposed_lib = Library("quantized_decomposed", "IMPL") + + +@impl(quantized_decomposed_lib, "quantize_per_tensor", "XLA") +def xla_quantize_per_tensor(input: torch.Tensor, scale: float, zero_point: int, + quant_min: int, quant_max: int, dtype: torch.dtype): + return _xla_quantize(input, torch.tensor([scale]), + torch.tensor([zero_point], dtype=dtype), quant_min, + quant_max, dtype) + + +@impl(quantized_decomposed_lib, "quantize_per_channel", "XLA") +def xla_quantize_per_channel(input: torch.Tensor, scale: torch.Tensor, + zero_point: torch.Tensor, axis: int, + quant_min: int, quant_max: int, + dtype: torch.dtype): + return _xla_quantize(input, scale, zero_point, quant_min, quant_max, dtype, + axis) + + +@impl(quantized_decomposed_lib, "dequantize_per_tensor", "XLA") +def xla_dequantize_per_tensor(input: torch.Tensor, scale: float, + zero_point: int, quant_min: int, quant_max: int, + dtype: torch.dtype): + return _xla_dequantize(input, torch.tensor([scale]), + torch.tensor([zero_point], dtype=dtype), quant_min, + quant_max, dtype) + + +@impl(quantized_decomposed_lib, "dequantize_per_channel", "XLA") +def xla_dequantize_per_tensor(input: torch.Tensor, scale: torch.Tensor, + zero_point: torch.Tensor, axis: int, + quant_min: int, quant_max: int, + dtype: torch.dtype): + return _xla_dequantize(input, scale, zero_point, quant_min, quant_max, dtype, + axis) + + +def _unpack_tensor_to_list(t: torch.Tensor): + if t.device.type == 'xla': + return t.cpu().numpy().tolist() + else: + return t.numpy().tolist() + + +def _check_scale_zp(input, scale, zero_point, axis, dtype): + # The followings are checked: + # 1. scale, zp are 1D tensor. + # 2. Lenghth of scale, zp matched the (de)quant dim. + # 3. zp dtype is the same as the quantized integer type. + assert len(scale.shape) == 1 and len(zero_point.shape) == 1 + assert zero_point.dtype == dtype + if axis == -1: + assert scale.numel() == 1 and zero_point.numel() == 1 + else: + assert axis >= 0 and axis < len(input.shape) + qdq_dim_size = input.shape[axis] + assert qdq_dim_size == scale.numel() and qdq_dim_size == zero_point.numel() + + +def _xla_quantize(input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + axis: int = -1): + _check_scale_zp(input, scale, zero_point, axis, dtype) + # Scale and zero_point need to be unpacked(materialized before enter LTC), + # because the quant param will be attached to tensor Shape in HLO/StableHLO. + scale_np = _unpack_tensor_to_list(scale) + zp_np = _unpack_tensor_to_list(zero_point) + # All scaler values needs to be greater than 0. (StableHLO qdq op constraint) + assert np.all(np.greater(scale_np, 0)) + return torch_xla._XLAC._xla_quantize_tensor(input, scale_np, zp_np, quant_min, + quant_max, str(dtype), axis) + + +def _xla_dequantize(input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + axis: int = -1): + _check_scale_zp(input, scale, zero_point, axis, dtype) + # Scale and zero_point need to be unpacked(materialized before enter LTC), + # because the quant param will be attached to tensor Shape in HLO/StableHLO. + scale_np = _unpack_tensor_to_list(scale) + zp_np = _unpack_tensor_to_list(zero_point) + # All scaler values needs to be greater than 0. (StableHLO qdq op constraint) + assert np.all(np.greater(scale_np, 0)) + return torch_xla._XLAC._xla_dequantize_tensor(input, scale_np, + zp_np, quant_min, quant_max, + str(dtype), axis) diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py index b03f92528e8b..d1eb267721f8 100644 --- a/torch_xla/stablehlo.py +++ b/torch_xla/stablehlo.py @@ -17,6 +17,7 @@ from torch_xla.core import xla_model as xm from torch_xla.core import dynamo_bridge from torch_xla.debug import metrics +import torch_xla.experimental.quantized import torch._dynamo as torchdynamo from torch.utils import _pytree as pytree @@ -356,7 +357,7 @@ def _exported_program_to_stablehlo_bundle(exported_model, output_signature = [ VariableSignature( shape=list(tensor.shape), - dtype=str(tensor_value.dtype).replace('torch.', '')) for tensor in res + dtype=str(tensor.dtype).replace('torch.', '')) for tensor in res ] torch_xla._XLAC._clear_pending_irs(str(device))