forked from pytorch/xla
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Lower quant/dequant torch op to StableHLO (pytorch#5763)
(de)quantize_per_tensor/channel ops from PT2E quantization workflow are lowered to stablehlo uniform_dequantize/quantize. --------- Co-authored-by: Siyuan Liu <lsiyuan@google.coim>
- Loading branch information
Showing
24 changed files
with
826 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
// TODO(lsy323): This is a patch on the HLO->StableHLO converter, this allows the custom call to | ||
// stablehlo.uniform_quantize/dequantize to be converted to stablehlo.uniform_quantize/dequantize. | ||
// The patch can be removed after quantize/dequantize, quantized dtype support is added to HLO. | ||
diff --git a/xla/translate/hlo_to_mhlo/BUILD b/xla/translate/hlo_to_mhlo/BUILD | ||
index f74973ae1..8e3f0e06b 100644 | ||
--- a/xla/translate/hlo_to_mhlo/BUILD | ||
+++ b/xla/translate/hlo_to_mhlo/BUILD | ||
@@ -67,6 +67,7 @@ cc_library( | ||
"@llvm-project//mlir:ArithDialect", | ||
"@llvm-project//mlir:AsmParser", | ||
"@llvm-project//mlir:FuncDialect", | ||
+ "@llvm-project//mlir:QuantOps", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:SparseTensorDialect", | ||
"@tsl//tsl/platform:statusor", | ||
diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc | ||
index 08d5f49c8..2f9ad1e0b 100644 | ||
--- a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc | ||
+++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc | ||
@@ -664,6 +664,70 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction( | ||
return importer.ImportInstructionWithLayout(instr, operands, builder, mode); | ||
} | ||
|
||
+Type getQuantizedType(mlir::DictionaryAttr& backend_config) { | ||
+ std::vector<double> scales; | ||
+ std::vector<int64_t> zero_points; | ||
+ int64_t quantization_dimension = -1, storage_max = 0, storage_min = 0; | ||
+ Type storage_type, expressed_type; | ||
+ | ||
+ auto scales_attr = backend_config.get("scale"); | ||
+ if (scales_attr) { | ||
+ for (auto scale_attr : scales_attr.cast<mlir::ArrayAttr>()) { | ||
+ scales.push_back(scale_attr.cast<mlir::FloatAttr>().getValueAsDouble()); | ||
+ } | ||
+ } | ||
+ | ||
+ auto zero_points_attr = backend_config.get("zero_point"); | ||
+ if (zero_points_attr) { | ||
+ for (auto zero_point_attr : zero_points_attr.cast<mlir::ArrayAttr>()) { | ||
+ zero_points.push_back(zero_point_attr.cast<mlir::IntegerAttr>().getInt()); | ||
+ } | ||
+ } | ||
+ | ||
+ auto quantization_dimension_attr = | ||
+ backend_config.get("quantization_dimension"); | ||
+ if (quantization_dimension_attr) { | ||
+ quantization_dimension = | ||
+ quantization_dimension_attr.cast<mlir::IntegerAttr>().getInt(); | ||
+ } | ||
+ | ||
+ auto storage_max_attr = backend_config.get("storage_max"); | ||
+ if (storage_max_attr) { | ||
+ storage_max = storage_max_attr.cast<mlir::IntegerAttr>().getInt(); | ||
+ } | ||
+ | ||
+ auto storage_min_attr = backend_config.get("storage_min"); | ||
+ if (storage_min_attr) { | ||
+ storage_min = storage_min_attr.cast<mlir::IntegerAttr>().getInt(); | ||
+ } | ||
+ | ||
+ auto storage_type_attr = backend_config.get("storage_type"); | ||
+ if (storage_type_attr) { | ||
+ storage_type = storage_type_attr.cast<mlir::TypeAttr>().getValue(); | ||
+ //.cast<mlir::ShapedType>() | ||
+ //.getElementType(); | ||
+ } | ||
+ | ||
+ auto expressed_type_attr = backend_config.get("expressed_type"); | ||
+ if (expressed_type_attr) { | ||
+ expressed_type = expressed_type_attr.cast<mlir::TypeAttr>().getValue(); | ||
+ //.cast<mlir::ShapedType>() | ||
+ //.getElementType(); | ||
+ } | ||
+ | ||
+ auto is_signed = storage_type.cast<mlir::IntegerType>().isSigned(); | ||
+ | ||
+ if (quantization_dimension != -1) { | ||
+ return mlir::quant::UniformQuantizedPerAxisType::get( | ||
+ is_signed, storage_type, expressed_type, scales, zero_points, | ||
+ quantization_dimension, storage_min, storage_max); | ||
+ } else { | ||
+ return mlir::quant::UniformQuantizedType::get( | ||
+ is_signed, storage_type, expressed_type, scales[0], zero_points[0], | ||
+ storage_min, storage_max); | ||
+ } | ||
+} | ||
+ | ||
StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl( | ||
const HloInstruction* instruction, | ||
const llvm::SmallVectorImpl<mlir::Value>& operands, | ||
@@ -933,6 +997,25 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl( | ||
"Couldn't parse backend config into a dictionary attribute"); | ||
|
||
attributes.push_back(builder_->getNamedAttr("backend_config", attr)); | ||
+ auto backend_config = attr.cast<mlir::DictionaryAttr>(); | ||
+ if (custom_call->custom_call_target() == | ||
+ "stablehlo.uniform_quantize") { | ||
+ return func_builder | ||
+ ->create<mlir::mhlo::UniformQuantizeOp>( | ||
+ loc, | ||
+ mlir::RankedTensorType::get( | ||
+ result_type.cast<RankedTensorType>().getShape(), | ||
+ getQuantizedType(backend_config)), | ||
+ operands) | ||
+ .getOperation(); | ||
+ } | ||
+ | ||
+ if (custom_call->custom_call_target() == | ||
+ "stablehlo.uniform_dequantize") { | ||
+ return func_builder | ||
+ ->create<mlir::mhlo::UniformDequantizeOp>( | ||
+ loc, result_type, operands) .getOperation(); | ||
+ } | ||
} | ||
} else { | ||
attributes.push_back(builder_->getNamedAttr( | ||
diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc | ||
index 9f05992c8..03cf4840d 100644 | ||
--- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc | ||
+++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc | ||
@@ -19,6 +19,8 @@ limitations under the License. | ||
#include <memory> | ||
#include <vector> | ||
|
||
+#include "mlir/Dialect/Quant/QuantOps.h" | ||
+#include "mlir/Dialect/Quant/QuantTypes.h" | ||
#include "mlir/IR/Attributes.h" // from @llvm-project | ||
#include "xla/hlo/ir/hlo_computation.h" | ||
#include "xla/hlo/ir/hlo_instruction.h" | ||
@@ -41,6 +43,7 @@ HloModuleImporter::HloModuleImporter(mlir::ModuleOp module, | ||
module.getContext()->loadDialect<mlir::arith::ArithDialect>(); | ||
module.getContext()->loadDialect<mlir::func::FuncDialect>(); | ||
module.getContext()->loadDialect<mlir::mhlo::MhloDialect>(); | ||
+ module.getContext()->loadDialect<mlir::quant::QuantizationDialect>(); | ||
} | ||
|
||
namespace { |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
// TODO(lsy323): This patch is needed to serialize stablehlo.uniform_quantize/dequantize in bytecode format | ||
// This patch can be removed after https://github.com/openxla/stablehlo/issues/1812 is fixed. | ||
diff --git a/third_party/stablehlo/stablehlo_quant_seralization.patch b/third_party/stablehlo/stablehlo_quant_seralization.patch | ||
new file mode 100644 | ||
index 000000000..24e23b67d | ||
--- /dev/null | ||
+++ b/third_party/stablehlo/stablehlo_quant_seralization.patch | ||
@@ -0,0 +1,26 @@ | ||
+diff --git a/stablehlo/api/PortableApi.cpp b/stablehlo/api/PortableApi.cpp | ||
+index 07c856db..cd169cae 100644 | ||
+--- a/stablehlo/api/PortableApi.cpp | ||
++++ b/stablehlo/api/PortableApi.cpp | ||
+@@ -15,10 +15,13 @@ limitations under the License. | ||
+ | ||
+ #include "stablehlo/api/PortableApi.h" | ||
+ | ||
++#include <iostream> | ||
+ #include <string> | ||
+ | ||
+ #include "mlir/Bytecode/BytecodeWriter.h" | ||
+ #include "mlir/Dialect/Func/IR/FuncOps.h" | ||
++#include "mlir/Dialect/Quant/QuantOps.h" | ||
++#include "mlir/Dialect/Quant/QuantTypes.h" | ||
+ #include "mlir/IR/MLIRContext.h" | ||
+ #include "mlir/Parser/Parser.h" | ||
+ #include "stablehlo/dialect/Serialization.h" | ||
+@@ -33,6 +36,7 @@ void loadSerializationDialects(MLIRContext* context) { | ||
+ context->loadDialect<mlir::func::FuncDialect>(); | ||
+ context->loadDialect<mlir::stablehlo::StablehloDialect>(); | ||
+ context->loadDialect<mlir::vhlo::VhloDialect>(); | ||
++ context->loadDialect<mlir::quant::QuantizationDialect>(); | ||
+ } | ||
+ } // namespace | ||
+ | ||
diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl | ||
index 9f4494aac..64fa072bb 100644 | ||
--- a/third_party/stablehlo/workspace.bzl | ||
+++ b/third_party/stablehlo/workspace.bzl | ||
@@ -15,5 +15,6 @@ def repo(): | ||
urls = tf_mirror_urls("https://github.com/openxla/stablehlo/archive/{commit}.zip".format(commit = STABLEHLO_COMMIT)), | ||
patch_file = [ | ||
"//third_party/stablehlo:temporary.patch", # Autogenerated, don't remove. | ||
+ "//third_party/stablehlo:stablehlo_quant_seralization.patch", # Load quant dialect. | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import os | ||
import unittest | ||
from typing import Callable, Dict, List | ||
|
||
import torch | ||
import torch_xla.core.xla_model as xm | ||
import torchvision | ||
from torch._export import capture_pre_autograd_graph | ||
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e | ||
from torch.ao.quantization.quantizer.xnnpack_quantizer import ( | ||
XNNPACKQuantizer, get_symmetric_quantization_config) | ||
from torch_xla import stablehlo | ||
from torch_xla.tf_saved_model_integration import \ | ||
save_torch_module_as_tf_saved_model | ||
|
||
# Needed to workaround the stablehlo bytecode serialization issue in https://github.com/openxla/stablehlo/issues/1812 | ||
os.environ['STABLEHLO_BYTECODE_FROM_PRETTYPRINT'] = '1' | ||
|
||
_TORCH_QUANTIZE_OPS = [ | ||
torch.ops.quantized_decomposed.quantize_per_tensor.default, | ||
torch.ops.quantized_decomposed.quantize_per_tensor.tensor, | ||
torch.ops.quantized_decomposed.quantize_per_channel.default, | ||
] | ||
|
||
_TORCH_DEQUANTIZE_OPS = [ | ||
torch.ops.quantized_decomposed.dequantize_per_tensor.default, | ||
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, | ||
torch.ops.quantized_decomposed.dequantize_per_channel.default, | ||
] | ||
|
||
|
||
def count_fx_graph_nodes(g: torch.fx.Graph, op_dict: Dict[str, List[Callable]]): | ||
cnt = {} | ||
for k in op_dict.keys(): | ||
cnt[k] = 0 | ||
for node in g.nodes: | ||
for k, ops in op_dict.items(): | ||
if node.target in ops: | ||
cnt[k] += 1 | ||
return cnt | ||
|
||
|
||
def count_qdq_ops(g: torch.fx.Graph): | ||
op_dict = { | ||
"qunatize": _TORCH_QUANTIZE_OPS, | ||
"dequantize": _TORCH_DEQUANTIZE_OPS, | ||
} | ||
return count_fx_graph_nodes(g, op_dict) | ||
|
||
|
||
class PT2EExportTest(unittest.TestCase): | ||
|
||
def test_per_tensor_qdq(self): | ||
device = xm.xla_device() | ||
x = torch.randn(2, 3, 4, 5).to(device) | ||
x = torch.ops.quantized_decomposed.quantize_per_tensor( | ||
x, 0.4, 2, -128, 127, torch.int8) | ||
x = torch.ops.quantized_decomposed.dequantize_per_tensor( | ||
x, 0.4, 2, -128, 127, torch.int8) | ||
stablehlo_txt = xm.get_stablehlo([x]) | ||
self.assertEqual( | ||
stablehlo_txt.count( | ||
'tensor<2x3x4x5x!quant.uniform<i8:f32, 4.000000e-01:2>>'), 2) | ||
self.assertEqual(stablehlo_txt.count("stablehlo.uniform_quantize"), 1) | ||
self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1) | ||
|
||
def test_per_channel_qdq(self): | ||
device = xm.xla_device() | ||
x = torch.randn(2, 3, 4, 5).to(device) | ||
scale = torch.tensor([3.2, 5.3, 0.1, 10]) | ||
zero_point = torch.tensor([1, 2, -1, -2], dtype=torch.int8) | ||
x = torch.ops.quantized_decomposed.quantize_per_channel( | ||
x, scale, zero_point, 2, -128, 127, torch.int8) | ||
x = torch.ops.quantized_decomposed.dequantize_per_channel( | ||
x, scale, zero_point, 2, -128, 127, torch.int8) | ||
stablehlo_txt = xm.get_stablehlo([x]) | ||
self.assertEqual( | ||
stablehlo_txt.count( | ||
'tensor<2x3x4x5x!quant.uniform<i8:f32:2, {3.200000e+00:1,5.300000e+00:2,1.000000e-01:-1,1.000000e+01:-2}>>' | ||
), 2) | ||
self.assertEqual(stablehlo_txt.count("stablehlo.uniform_quantize"), 1) | ||
self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1) | ||
|
||
def test_resnet18(self): | ||
# Step 1: export resnet18 | ||
args = (torch.randn(1, 3, 224, 224),) | ||
m = torchvision.models.resnet18().eval() | ||
m = capture_pre_autograd_graph(m, args) | ||
|
||
# Step 2: Insert observers or fake quantize modules | ||
quantizer = XNNPACKQuantizer().set_global( | ||
get_symmetric_quantization_config()) | ||
m = prepare_pt2e(m, quantizer) | ||
|
||
# Step 3: Quantize the model | ||
m = convert_pt2e(m) | ||
|
||
# Trace with torch/xla and export stablehlo | ||
exported = torch.export.export(m, args) | ||
stablehlo_gm = stablehlo.exported_program_to_stablehlo(exported) | ||
stablehlo_txt = stablehlo_gm.get_stablehlo_text() | ||
fx_node_cnt = count_qdq_ops(exported.graph_module.graph) | ||
self.assertEqual( | ||
stablehlo_txt.count("stablehlo.uniform_quantize"), | ||
fx_node_cnt["qunatize"]) | ||
self.assertEqual( | ||
stablehlo_txt.count("stablehlo.uniform_dequantize"), | ||
fx_node_cnt["dequantize"]) | ||
# Save as tf.saved_model | ||
save_path = '/tmp/tf_saved_model/tmp1' | ||
save_torch_module_as_tf_saved_model(m, args, save_path) | ||
self.assertTrue(os.path.exists(os.path.join(save_path, 'saved_model.pb'))) | ||
|
||
|
||
if __name__ == '__main__': | ||
test = unittest.main() | ||
sys.exit(0 if test.result.wasSuccessful() else 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.