Lower quant/dequant torch op to StableHLO #5763
Conversation
Commits e70be80 to c329b63:
- lowering to HLO custom call
- refactor, add quant util, rename test script
- clean up quant op
import torch
import torchvision
from torch._export import capture_pre_autograd_graph

# Step 1: export resnet18
args = (torch.randn(1, 3, 224, 224),)
m = torchvision.models.resnet18().eval()
m = capture_pre_autograd_graph(m, args)
Is there a reason we use this instead of torch.export?
OK, I saw the export below, but I'm still confused about what this function does to the module.
Here the graph is captured for PT2E to process further. PT2E doesn't work with a graph captured from torch.export (just tried it locally); it needs the graph captured this way.
The export down below is for the PyTorch -> StableHLO export; our API only works on an ExportedProgram.
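For context, here is a minimal end-to-end sketch of the flow being discussed, assuming the PT2E quantization entry points (prepare_pt2e/convert_pt2e with XNNPACKQuantizer) and torch_xla's exported_program_to_stablehlo; the specific quantizer and API choices below are illustrative assumptions, not prescribed by this PR.

import torch
import torchvision
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer, get_symmetric_quantization_config)
from torch_xla.stablehlo import exported_program_to_stablehlo

args = (torch.randn(1, 3, 224, 224),)
m = torchvision.models.resnet18().eval()

# Capture the graph for PT2E; a torch.export-captured graph is not accepted here.
m = capture_pre_autograd_graph(m, args)

# PT2E quantization: insert observers, calibrate, then convert, which
# materializes the (de)quantize_per_tensor/channel ops this PR lowers.
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)
m(*args)  # calibration pass
m = convert_pt2e(m)

# torch.export for the PyTorch -> StableHLO path; the StableHLO API
# works on an ExportedProgram.
exported = torch.export.export(m, args)
stablehlo_gm = exported_program_to_stablehlo(exported)
# The printed module should contain stablehlo.uniform_quantize/dequantize.
print(stablehlo_gm.get_stablehlo_text())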
Update:
(de)quantize_per_tensor/channel ops from PT2E quantization workflow are lowered to stablehlo uniform_dequantize/quantize. --------- Co-authored-by: Siyuan Liu <lsiyuan@google.com>
The following torch ops can be lowered to StableHLO with this diff: quantize_per_tensor, dequantize_per_tensor, quantize_per_channel, and dequantize_per_channel (from the PT2E quantization workflow).
User Experience
STABLEHLO_BYTECODE_FROM_PRETTYPRINT needs to be set to 1 to work around a StableHLO bytecode serialization issue.
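For example, the flag can be set in the process environment before the export runs (a sketch; the variable name comes from this PR, everything else is generic):

import os
# Work around the StableHLO bytecode serialization issue mentioned above.
os.environ["STABLEHLO_BYTECODE_FROM_PRETTYPRINT"] = "1"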
Current workflow
The (de)quantize torch ops are lowered to stablehlo.uniform_quantize/dequantize custom calls in HLO. The qparams are stored in the custom call config str, which can be deserialized to an mlir DictAttr directly. When the HLO is converted to StableHLO, these custom calls become stablehlo.uniform_quantize/dequantize ops.
Changes
Update save_torch_module_as_tf_saved_model to take a GraphModule as well, since PT2E outputs a GraphModule.
Two other changes are added: one is the stablehlo.uniform_quantize/dequantize conversion, originally authored by @sdasgup3; another is to work around the StableHLO bytecode serialization issue mentioned above. Both won't be needed if an HLO qdtype representation is added.
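A usage sketch of the updated entry point, assuming it lives in torch_xla.tf_saved_model_integration (the output directory below is a hypothetical example): the GraphModule produced by convert_pt2e can now be passed directly.

from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model

# m is the GraphModule returned by convert_pt2e; args are the sample inputs.
save_torch_module_as_tf_saved_model(m, args, '/tmp/resnet18_pt2e_saved_model')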
Future Work
cc @sdasgup3 @GleasonK @paulinesho