
Lower quant/dequant torch op to StableHLO #5763

Merged

lsy323 merged 23 commits into master on Nov 28, 2023
Conversation

lsy323 (Collaborator) commented Nov 2, 2023

The following torch ops can be lowered to StableHLO with this diff (a brief usage sketch follows the list):

  • quantize_per_tensor
  • quantize_per_channel
  • dequantize_per_tensor
  • dequantize_per_channel
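
For reference, these are the decomposed qdq ops that PT2E emits. Below is a minimal sketch of a per-tensor quantize/dequantize round trip, assuming the torch.ops.quantized_decomposed namespace and illustrative int8 qparams:

import torch

x = torch.randn(4, 8)
scale, zero_point = 0.05, 0
# Quantize to int8 with the given scale/zero_point and clamp range,
# then dequantize back to float.
q = torch.ops.quantized_decomposed.quantize_per_tensor(
    x, scale, zero_point, -128, 127, torch.int8)
dq = torch.ops.quantized_decomposed.dequantize_per_tensor(
    q, scale, zero_point, -128, 127, torch.int8)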

User Experience

  • The GraphModule generated by PT2E quantization can be exported to StableHLO, or to tf.saved_model, using the existing export APIs without any additional change to model code or export scripts. STABLEHLO_BYTECODE_FROM_PRETTYPRINT needs to be set to 1 to work around a StableHLO bytecode serialization issue. A short sketch of this flow follows.
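
A hedged sketch of that export path, assuming a PT2E-converted GraphModule quantized_gm with example inputs args; save_torch_module_as_tf_saved_model is the existing API in torch_xla.tf_saved_model_integration, and the output directory is illustrative:

import os
os.environ['STABLEHLO_BYTECODE_FROM_PRETTYPRINT'] = '1'  # serialization workaround

from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model

# quantized_gm: GraphModule from convert_pt2e; args: example input tuple.
save_torch_module_as_tf_saved_model(quantized_gm, args, '/tmp/quantized_saved_model')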

Current workflow

  1. Register the xla qdq ops under the 'XLA' dispatch key, so that the qdq ops are dispatched to the XLA implementation during LTC tracing (a minimal dispatch sketch follows this list).
  2. During lowering, qdq ops are lowered to a custom call to stablehlo.uniform_quantize/dequantize in HLO. The qparams are stored in the custom call's config string, which can be deserialized directly into an MLIR DictAttr.
  3. The HLO->StableHLO converter converts the custom call to stablehlo.uniform_quantize/dequantize.
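
To make step 1 concrete, here is a minimal sketch of routing an op to the XLA backend via the public torch.library API; the library name, schema, and naive reference math are hypothetical stand-ins, not the PR's implementation (which emits the custom call described in step 2):

import torch

demo_lib = torch.library.Library("demo_q", "DEF")
demo_lib.define("quantize_tensor(Tensor x, float scale, int zero_point) -> Tensor")

def xla_quantize_tensor(x, scale, zero_point):
    # Naive int8 affine quantization stand-in, traced by LTC on XLA tensors.
    return torch.clamp(torch.round(x / scale) + zero_point, -128, 127).to(torch.int8)

# Registering under the 'XLA' dispatch key routes calls on XLA tensors here.
demo_lib.impl("quantize_tensor", xla_quantize_tensor, "XLA")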

Changes

  1. Allow save_torch_module_as_tf_saved_model to take a GraphModule as well, since PT2E outputs a GraphModule.
  2. Added 2 patches. One adds support to the HLO->StableHLO converter for stablehlo.uniform_quantize/dequantize conversion, originally authored by @sdasgup3. The other works around the StableHLO bytecode serialization issue mentioned above. Neither will be needed once a quantized dtype representation is added to HLO.
  3. Added new xla quantize_tensor/dequantize_tensor ops for qdq op lowering. The xla quantize/dequantize ops lower to custom calls to stablehlo.uniform_quantize/dequantize.
  4. Added a test script covering export of per-tensor/per-channel qdq ops and a PT2E-quantized resnet18 model (the PT2E flow is sketched after this list).
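
For context on item 4, here is a sketch of the PT2E quantization flow such a test would exercise, assuming the XNNPACKQuantizer from torch.ao.quantization with a symmetric config (both illustrative choices, not necessarily the PR's exact test setup):

import torch
import torchvision
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer, get_symmetric_quantization_config)

args = (torch.randn(1, 3, 224, 224),)
m = torchvision.models.resnet18().eval()
m = capture_pre_autograd_graph(m, args)

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)
m(*args)                        # one calibration pass with sample inputs
quantized_gm = convert_pt2e(m)  # GraphModule containing the qdq ops above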

Future Work

  1. When qdtype is added to HLO, the lowering logic will need to be updated, and it will be more concise than the current implementation.

cc @sdasgup3 @GleasonK @paulinesho

@lsy323 lsy323 requested review from miladm, qihqi and JackCaoG November 2, 2023 16:59
@lsy323 lsy323 force-pushed the lsiyuan/quant-dequant-dispatch branch from e70be80 to c329b63 Compare November 18, 2023 00:39
@lsy323 lsy323 requested a review from JackCaoG November 27, 2023 18:19
import torch
import torchvision
from torch._export import capture_pre_autograd_graph

# Step 1: export resnet18
args = (torch.randn(1, 3, 224, 224),)
m = torchvision.models.resnet18().eval()
m = capture_pre_autograd_graph(m, args)
Collaborator

is there a reason we use this instead of torch.export?

Collaborator

ok I saw the export below, but I'm still confused about what this function does to the module.

lsy323 (Collaborator, Author) Nov 28, 2023

Here the graph is captured for PT2E to process further. PT2E doesn't work with a graph captured via torch.export (just tried locally); it needs the graph to be captured this way.

The export further down is for PyTorch -> StableHLO exporting; our API only works on an ExportedProgram (sketched below).
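
For concreteness, a hedged sketch of that later export step, assuming torch_xla's exported_program_to_stablehlo helper and the quantized_gm and args from the PT2E flow above:

import torch
from torch_xla.stablehlo import exported_program_to_stablehlo

ep = torch.export.export(quantized_gm, args)  # the API accepts only an ExportedProgram
shlo = exported_program_to_stablehlo(ep)
# The emitted text contains stablehlo.uniform_quantize/uniform_dequantize ops.
print(shlo.get_stablehlo_text())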

lsy323 (Collaborator, Author) commented Nov 28, 2023

Update:

  • Addressed review comments
  • Enhanced the test script to check the qparams of the qdq StableHLO ops and the number of qdq ops (an illustrative check follows)
  • Added more assertions to the torch_xla qdq ops: the scale and zero_point shapes, the zero_point dtype matching the integer dtype of the quantized type, and the scale values all being positive
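
An illustrative version of such a check (assumed, not the PR's exact test code), counting qdq ops in the StableHLO text from the earlier export sketch:

stablehlo_text = shlo.get_stablehlo_text()
# One quantize/dequantize pair is expected for a single per-tensor qdq pattern.
assert stablehlo_text.count('stablehlo.uniform_quantize') == 1
assert stablehlo_text.count('stablehlo.uniform_dequantize') == 1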

@lsy323 lsy323 requested a review from JackCaoG November 28, 2023 07:50
@lsy323 lsy323 added the stablehlo StableHLO related work label Nov 28, 2023
@lsy323 lsy323 merged commit a3b0c6e into master Nov 28, 2023
18 checks passed
ManfeiBai pushed a commit to ManfeiBai/PyTorchXLA that referenced this pull request Dec 1, 2023

(de)quantize_per_tensor/channel ops from PT2E quantization workflow are lowered to stablehlo uniform_dequantize/quantize.
---------

Co-authored-by: Siyuan Liu <lsiyuan@google.com>
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
@lsy323 lsy323 deleted the lsiyuan/quant-dequant-dispatch branch March 4, 2024 19:12
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024