Skip to content

Commit

Permalink
Lower quant/dequant torch op to StableHLO (#5763)
Browse files Browse the repository at this point in the history
(de)quantize_per_tensor/channel ops from PT2E quantization workflow are lowered to stablehlo uniform_dequantize/quantize.
---------

Co-authored-by: Siyuan Liu <lsiyuan@google.coim>
  • Loading branch information
2 people authored and bhavya01 committed Apr 22, 2024
1 parent b78e3fa commit 3fa33fe
Show file tree
Hide file tree
Showing 24 changed files with 826 additions and 16 deletions.
2 changes: 2 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ http_archive(
"//openxla_patches:gpu_race_condition.diff",
"//openxla_patches:f16_abi_clang.diff",
"//openxla_patches:gpu_topk_rewriter.diff",
"//openxla_patches:quant_dequant_converter.diff",
"//openxla_patches:stablehlo_quant_seralization.diff",
],
strip_prefix = "xla-4f8381651977dff16b1d86bb4b198eb733c5f478",
urls = [
Expand Down
137 changes: 137 additions & 0 deletions openxla_patches/quant_dequant_converter.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// TODO(lsy323): This is a patch on the HLO->StableHLO converter, this allows the custom call to
// stablehlo.uniform_quantize/dequantize to be converted to stablehlo.uniform_quantize/dequantize.
// The patch can be removed after quantize/dequantize, quantized dtype support is added to HLO.
diff --git a/xla/translate/hlo_to_mhlo/BUILD b/xla/translate/hlo_to_mhlo/BUILD
index f74973ae1..8e3f0e06b 100644
--- a/xla/translate/hlo_to_mhlo/BUILD
+++ b/xla/translate/hlo_to_mhlo/BUILD
@@ -67,6 +67,7 @@ cc_library(
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:AsmParser",
"@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SparseTensorDialect",
"@tsl//tsl/platform:statusor",
diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
index 08d5f49c8..2f9ad1e0b 100644
--- a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
+++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
@@ -664,6 +664,70 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
return importer.ImportInstructionWithLayout(instr, operands, builder, mode);
}

+Type getQuantizedType(mlir::DictionaryAttr& backend_config) {
+ std::vector<double> scales;
+ std::vector<int64_t> zero_points;
+ int64_t quantization_dimension = -1, storage_max = 0, storage_min = 0;
+ Type storage_type, expressed_type;
+
+ auto scales_attr = backend_config.get("scale");
+ if (scales_attr) {
+ for (auto scale_attr : scales_attr.cast<mlir::ArrayAttr>()) {
+ scales.push_back(scale_attr.cast<mlir::FloatAttr>().getValueAsDouble());
+ }
+ }
+
+ auto zero_points_attr = backend_config.get("zero_point");
+ if (zero_points_attr) {
+ for (auto zero_point_attr : zero_points_attr.cast<mlir::ArrayAttr>()) {
+ zero_points.push_back(zero_point_attr.cast<mlir::IntegerAttr>().getInt());
+ }
+ }
+
+ auto quantization_dimension_attr =
+ backend_config.get("quantization_dimension");
+ if (quantization_dimension_attr) {
+ quantization_dimension =
+ quantization_dimension_attr.cast<mlir::IntegerAttr>().getInt();
+ }
+
+ auto storage_max_attr = backend_config.get("storage_max");
+ if (storage_max_attr) {
+ storage_max = storage_max_attr.cast<mlir::IntegerAttr>().getInt();
+ }
+
+ auto storage_min_attr = backend_config.get("storage_min");
+ if (storage_min_attr) {
+ storage_min = storage_min_attr.cast<mlir::IntegerAttr>().getInt();
+ }
+
+ auto storage_type_attr = backend_config.get("storage_type");
+ if (storage_type_attr) {
+ storage_type = storage_type_attr.cast<mlir::TypeAttr>().getValue();
+ //.cast<mlir::ShapedType>()
+ //.getElementType();
+ }
+
+ auto expressed_type_attr = backend_config.get("expressed_type");
+ if (expressed_type_attr) {
+ expressed_type = expressed_type_attr.cast<mlir::TypeAttr>().getValue();
+ //.cast<mlir::ShapedType>()
+ //.getElementType();
+ }
+
+ auto is_signed = storage_type.cast<mlir::IntegerType>().isSigned();
+
+ if (quantization_dimension != -1) {
+ return mlir::quant::UniformQuantizedPerAxisType::get(
+ is_signed, storage_type, expressed_type, scales, zero_points,
+ quantization_dimension, storage_min, storage_max);
+ } else {
+ return mlir::quant::UniformQuantizedType::get(
+ is_signed, storage_type, expressed_type, scales[0], zero_points[0],
+ storage_min, storage_max);
+ }
+}
+
StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
const HloInstruction* instruction,
const llvm::SmallVectorImpl<mlir::Value>& operands,
@@ -933,6 +997,25 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
"Couldn't parse backend config into a dictionary attribute");

attributes.push_back(builder_->getNamedAttr("backend_config", attr));
+ auto backend_config = attr.cast<mlir::DictionaryAttr>();
+ if (custom_call->custom_call_target() ==
+ "stablehlo.uniform_quantize") {
+ return func_builder
+ ->create<mlir::mhlo::UniformQuantizeOp>(
+ loc,
+ mlir::RankedTensorType::get(
+ result_type.cast<RankedTensorType>().getShape(),
+ getQuantizedType(backend_config)),
+ operands)
+ .getOperation();
+ }
+
+ if (custom_call->custom_call_target() ==
+ "stablehlo.uniform_dequantize") {
+ return func_builder
+ ->create<mlir::mhlo::UniformDequantizeOp>(
+ loc, result_type, operands) .getOperation();
+ }
}
} else {
attributes.push_back(builder_->getNamedAttr(
diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
index 9f05992c8..03cf4840d 100644
--- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
+++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include <memory>
#include <vector>

+#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
@@ -41,6 +43,7 @@ HloModuleImporter::HloModuleImporter(mlir::ModuleOp module,
module.getContext()->loadDialect<mlir::arith::ArithDialect>();
module.getContext()->loadDialect<mlir::func::FuncDialect>();
module.getContext()->loadDialect<mlir::mhlo::MhloDialect>();
+ module.getContext()->loadDialect<mlir::quant::QuantizationDialect>();
}

namespace {
45 changes: 45 additions & 0 deletions openxla_patches/stablehlo_quant_seralization.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// TODO(lsy323): This patch is needed to serialize stablehlo.uniform_quantize/dequantize in bytecode format
// This patch can be removed after https://github.com/openxla/stablehlo/issues/1812 is fixed.
diff --git a/third_party/stablehlo/stablehlo_quant_seralization.patch b/third_party/stablehlo/stablehlo_quant_seralization.patch
new file mode 100644
index 000000000..24e23b67d
--- /dev/null
+++ b/third_party/stablehlo/stablehlo_quant_seralization.patch
@@ -0,0 +1,26 @@
+diff --git a/stablehlo/api/PortableApi.cpp b/stablehlo/api/PortableApi.cpp
+index 07c856db..cd169cae 100644
+--- a/stablehlo/api/PortableApi.cpp
++++ b/stablehlo/api/PortableApi.cpp
+@@ -15,10 +15,13 @@ limitations under the License.
+
+ #include "stablehlo/api/PortableApi.h"
+
++#include <iostream>
+ #include <string>
+
+ #include "mlir/Bytecode/BytecodeWriter.h"
+ #include "mlir/Dialect/Func/IR/FuncOps.h"
++#include "mlir/Dialect/Quant/QuantOps.h"
++#include "mlir/Dialect/Quant/QuantTypes.h"
+ #include "mlir/IR/MLIRContext.h"
+ #include "mlir/Parser/Parser.h"
+ #include "stablehlo/dialect/Serialization.h"
+@@ -33,6 +36,7 @@ void loadSerializationDialects(MLIRContext* context) {
+ context->loadDialect<mlir::func::FuncDialect>();
+ context->loadDialect<mlir::stablehlo::StablehloDialect>();
+ context->loadDialect<mlir::vhlo::VhloDialect>();
++ context->loadDialect<mlir::quant::QuantizationDialect>();
+ }
+ } // namespace
+
diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl
index 9f4494aac..64fa072bb 100644
--- a/third_party/stablehlo/workspace.bzl
+++ b/third_party/stablehlo/workspace.bzl
@@ -15,5 +15,6 @@ def repo():
urls = tf_mirror_urls("https://github.com/openxla/stablehlo/archive/{commit}.zip".format(commit = STABLEHLO_COMMIT)),
patch_file = [
"//third_party/stablehlo:temporary.patch", # Autogenerated, don't remove.
+ "//third_party/stablehlo:stablehlo_quant_seralization.patch", # Load quant dialect.
],
)
117 changes: 117 additions & 0 deletions test/stablehlo/test_pt2e_qdq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import os
import unittest
from typing import Callable, Dict, List

import torch
import torch_xla.core.xla_model as xm
import torchvision
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
XNNPACKQuantizer, get_symmetric_quantization_config)
from torch_xla import stablehlo
from torch_xla.tf_saved_model_integration import \
save_torch_module_as_tf_saved_model

# Needed to workaround the stablehlo bytecode serialization issue in https://github.com/openxla/stablehlo/issues/1812
os.environ['STABLEHLO_BYTECODE_FROM_PRETTYPRINT'] = '1'

_TORCH_QUANTIZE_OPS = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.quantize_per_channel.default,
]

_TORCH_DEQUANTIZE_OPS = [
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
]


def count_fx_graph_nodes(g: torch.fx.Graph, op_dict: Dict[str, List[Callable]]):
cnt = {}
for k in op_dict.keys():
cnt[k] = 0
for node in g.nodes:
for k, ops in op_dict.items():
if node.target in ops:
cnt[k] += 1
return cnt


def count_qdq_ops(g: torch.fx.Graph):
op_dict = {
"qunatize": _TORCH_QUANTIZE_OPS,
"dequantize": _TORCH_DEQUANTIZE_OPS,
}
return count_fx_graph_nodes(g, op_dict)


class PT2EExportTest(unittest.TestCase):

def test_per_tensor_qdq(self):
device = xm.xla_device()
x = torch.randn(2, 3, 4, 5).to(device)
x = torch.ops.quantized_decomposed.quantize_per_tensor(
x, 0.4, 2, -128, 127, torch.int8)
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, 0.4, 2, -128, 127, torch.int8)
stablehlo_txt = xm.get_stablehlo([x])
self.assertEqual(
stablehlo_txt.count(
'tensor<2x3x4x5x!quant.uniform<i8:f32, 4.000000e-01:2>>'), 2)
self.assertEqual(stablehlo_txt.count("stablehlo.uniform_quantize"), 1)
self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1)

def test_per_channel_qdq(self):
device = xm.xla_device()
x = torch.randn(2, 3, 4, 5).to(device)
scale = torch.tensor([3.2, 5.3, 0.1, 10])
zero_point = torch.tensor([1, 2, -1, -2], dtype=torch.int8)
x = torch.ops.quantized_decomposed.quantize_per_channel(
x, scale, zero_point, 2, -128, 127, torch.int8)
x = torch.ops.quantized_decomposed.dequantize_per_channel(
x, scale, zero_point, 2, -128, 127, torch.int8)
stablehlo_txt = xm.get_stablehlo([x])
self.assertEqual(
stablehlo_txt.count(
'tensor<2x3x4x5x!quant.uniform<i8:f32:2, {3.200000e+00:1,5.300000e+00:2,1.000000e-01:-1,1.000000e+01:-2}>>'
), 2)
self.assertEqual(stablehlo_txt.count("stablehlo.uniform_quantize"), 1)
self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1)

def test_resnet18(self):
# Step 1: export resnet18
args = (torch.randn(1, 3, 224, 224),)
m = torchvision.models.resnet18().eval()
m = capture_pre_autograd_graph(m, args)

# Step 2: Insert observers or fake quantize modules
quantizer = XNNPACKQuantizer().set_global(
get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)

# Step 3: Quantize the model
m = convert_pt2e(m)

# Trace with torch/xla and export stablehlo
exported = torch.export.export(m, args)
stablehlo_gm = stablehlo.exported_program_to_stablehlo(exported)
stablehlo_txt = stablehlo_gm.get_stablehlo_text()
fx_node_cnt = count_qdq_ops(exported.graph_module.graph)
self.assertEqual(
stablehlo_txt.count("stablehlo.uniform_quantize"),
fx_node_cnt["qunatize"])
self.assertEqual(
stablehlo_txt.count("stablehlo.uniform_dequantize"),
fx_node_cnt["dequantize"])
# Save as tf.saved_model
save_path = '/tmp/tf_saved_model/tmp1'
save_torch_module_as_tf_saved_model(m, args, save_path)
self.assertTrue(os.path.exists(os.path.join(save_path, 'saved_model.pb')))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
2 changes: 2 additions & 0 deletions torch_xla/csrc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ ptxla_cc_library(
"nll_loss.cpp",
"nms_op.cpp",
"pooling.cpp",
"quant_util.cpp",
"random.cpp",
"reduction.cpp",
"resize_ops.cpp",
Expand Down Expand Up @@ -88,6 +89,7 @@ ptxla_cc_library(
"nll_loss.h",
"nms_op.h",
"pooling.h",
"quant_util.h",
"random.h",
"reduction.h",
"resize_ops.h",
Expand Down
46 changes: 46 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,28 @@ at::Tensor AllReduce(const std::string& reduce_type, const at::Tensor& input,
return bridge::AtenFromXlaTensor(std::move(result));
}

at::Tensor QuantizeTensor(const at::Tensor& input,
const std::vector<float>& scale_list,
const std::vector<int>& zero_point_list,
int quant_min, int quant_max,
const std::string& dtype, int axis) {
auto result = tensor_methods::quantize_tensor(
bridge::GetXlaTensor(input), scale_list, zero_point_list, quant_min,
quant_max, dtype, axis);
return bridge::AtenFromXlaTensor(std::move(result));
}

at::Tensor DequantizeTensor(const at::Tensor& input,
const std::vector<float>& scale_list,
const std::vector<int>& zero_point_list,
int quant_min, int quant_max,
const std::string& dtype, int axis) {
auto result = tensor_methods::dequantize_tensor(
bridge::GetXlaTensor(input), scale_list, zero_point_list, quant_min,
quant_max, dtype, axis);
return bridge::AtenFromXlaTensor(std::move(result));
}

std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> ReduceScatter(
const std::string& reduce_type, const at::Tensor& input,
const std::shared_ptr<torch::lazy::Value>& token, double scale,
Expand Down Expand Up @@ -1112,6 +1134,30 @@ void InitXlaModuleBindings(py::module m) {
}
return result;
});
m.def("_xla_quantize_tensor",
[](const at::Tensor& input, const std::vector<float>& scale_list,
const std::vector<int>& zero_point_list, int quant_min,
int quant_max, const std::string& dtype, int axis) -> at::Tensor {
at::Tensor result;
{
NoGilSection nogil;
result = QuantizeTensor(input, scale_list, zero_point_list,
quant_min, quant_max, dtype, axis);
}
return result;
});
m.def("_xla_dequantize_tensor",
[](const at::Tensor& input, const std::vector<float>& scale_list,
const std::vector<int>& zero_point_list, int quant_min,
int quant_max, const std::string& dtype, int axis) -> at::Tensor {
at::Tensor result;
{
NoGilSection nogil;
result = DequantizeTensor(input, scale_list, zero_point_list,
quant_min, quant_max, dtype, axis);
}
return result;
});
m.def("_xla_all_to_all",
[](const at::Tensor& input,
const std::shared_ptr<torch::lazy::Value>& token,
Expand Down
Loading

0 comments on commit 3fa33fe

Please sign in to comment.