stablehlo.uniform_quantize cannot be serialized to bytecode #1812
Comments
cc @GleasonK |
Thanks for reporting! Will take a look early next week. |
Hey @GleasonK, any update on this issue? We're currently using a workaround for this that we'd like to get rid of. |
The issue is related to the propagation of signless information for the integer storage type during PyTorch to StableHLO export. StableHLO, as bootstrapped from MHLO, inherits signless integers, which were added in MHLO for some legacy reasons and are treated as signed integers.

During PyTorch --> HLO export,

During HLO --> StableHLO,

During StableHLO -> VHLO, we do not convert such signed types (cs), following the rationale that only signless integers are considered to be signed integers unless #22 is resolved.

Note that this issue was not reproduced when we serialize (i.e. generate the bytecode) using a string representation. That is mainly because

My proposal is for the PyTorch -> StableHLO export to generate signless integers for the storage type. Here is the prototype of the proposal: pytorch/xla#6613

Original hlo

The thing to note here is the
Pretty printed stablehlo
@lsy323 Can you please check if the proposal pytorch/xla#6613 works with your setup before I send it for PR review? |
The current StableHLO quantized signed int check ([cs](https://github.com/openxla/stablehlo/blob/4a26ddee5fdbe178a84f219513bfc48f565919b7/stablehlo/dialect/Base.td#L60)) uses `mlir::quant::QuantizedType::isSigned` to decide on the signedness of the storage type. StableHLO, as bootstrapped from MHLO, inherits `signless` integers (added in MHLO for some legacy reasons), which are treated as signed integers. This is not captured by the check `mlir::quant::QuantizedType::isSigned`, for the following reason.

Based on [this](https://github.com/llvm/llvm-project/blob/16e74fd48988ac95551d0f64e1b36f78a82a89a2/mlir/include/mlir/Dialect/Quant/QuantTypes.h#L102), `mlir::quant::QuantizedType::isSigned()` is true if the associated bit in [flag](https://github.com/llvm/llvm-project/blob/16e74fd48988ac95551d0f64e1b36f78a82a89a2/mlir/lib/Dialect/Quant/IR/TypeDetail.h#L30) is 1. Here are a few ways to set that flag to true.

```
auto signed_flag_bit = storage_type.cast<mlir::IntegerType>().isSignless();
...
mlir::quant::UniformQuantizedPerAxisType::get(
    is_signed, storage_type, expressed_type, scales, zero_points,
    quantization_dimension, storage_min, storage_max);
```

or

```
auto signed_flag_bit = storage_type.cast<mlir::IntegerType>().isSigned();
...
mlir::quant::UniformQuantizedPerAxisType::get(
    is_signed, storage_type, expressed_type, scales, zero_points,
    quantization_dimension, storage_min, storage_max);
```

In other words, it is on the producers of `mlir::quant::QuantizedType` to set that bit based on **their interpretation of signedness**, which in the case of StableHLO is signless. That means a `true` value of `mlir::quant::QuantizedType::isSigned()` is not enough to ensure the desired signedness of the storage type.
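To make the point about producer conventions concrete, here is a minimal, dependency-free sketch. It does not use MLIR: `ToyIntegerType`, `ToyQuantizedType`, `kSignedFlag`, and the two `make...` producers are hypothetical stand-ins invented for illustration, and the only assumption carried over from the discussion is that `isSigned()` reports a flag bit rather than inspecting the storage type.

```cpp
#include <cassert>

// Toy stand-ins for the types discussed above; these are NOT the real MLIR classes.
enum class Signedness { Signless, Signed, Unsigned };

struct ToyIntegerType {
  unsigned width;
  Signedness signedness;
  bool isSignless() const { return signedness == Signedness::Signless; }
  bool isSigned() const { return signedness == Signedness::Signed; }
};

// Hypothetical "signed" bit kept in a flags word.
constexpr unsigned kSignedFlag = 1;

struct ToyQuantizedType {
  unsigned flags;
  ToyIntegerType storage;
  // Reports only the flag bit; it never looks at `storage`.
  bool isSigned() const { return (flags & kSignedFlag) != 0; }
};

// Producer convention A: signless storage is interpreted as signed.
ToyQuantizedType makeSignlessMeansSigned(ToyIntegerType storage) {
  unsigned flags = storage.isSignless() ? kSignedFlag : 0u;
  return {flags, storage};
}

// Producer convention B: only explicitly signed storage sets the bit.
ToyQuantizedType makeExplicitSigned(ToyIntegerType storage) {
  unsigned flags = storage.isSigned() ? kSignedFlag : 0u;
  return {flags, storage};
}

// The same signless i8 storage yields different isSigned() answers
// depending on which producer built the quantized type.
void runProducerExamples() {
  ToyIntegerType i8{8, Signedness::Signless};
  assert(makeSignlessMeansSigned(i8).isSigned());
  assert(!makeExplicitSigned(i8).isSigned());
}
```

Calling `runProducerExamples()` exercises both conventions. In this toy model, a consumer that sees `isSigned() == true` cannot tell which kind of storage type sits underneath.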
## Suggested change

Replace

```
CPred<"$_self.cast<mlir::quant::UniformQuantizedType>()" # ".isSigned()">]>,
```

with

```
CPred<"$_self.cast<mlir::quant::UniformQuantizedType>()" # ".getStorageType().cast<mlir::IntegerType>().isSignless()">]>,
```

IMO, we can skip the check `mlir::quant::UniformQuantizedType::isSigned()` altogether, as it is neither sufficient nor necessary.

Context: #1812 (comment)
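The "neither sufficient nor necessary" claim can be illustrated with a small sketch that compares the two predicates on a toy type. Everything here is hypothetical (`ToyQuantizedType`, `currentCheck`, `proposedCheck`); it models only the logic of the two `CPred` expressions, not the TableGen machinery or the real MLIR API.

```cpp
#include <cassert>

// Toy model, not MLIR: a quantized type reduced to its signed flag bit
// and the signedness of its integer storage type.
enum class Signedness { Signless, Signed, Unsigned };

struct ToyQuantizedType {
  bool signedFlag;
  Signedness storageSignedness;
};

// Models the existing predicate: trusts the flag bit alone.
bool currentCheck(const ToyQuantizedType &q) { return q.signedFlag; }

// Models the suggested predicate: inspects the storage type directly.
bool proposedCheck(const ToyQuantizedType &q) {
  return q.storageSignedness == Signedness::Signless;
}

void runPredicateExamples() {
  // Flag set, but storage is explicitly signed: the flag check passes
  // even though the storage is not signless (flag is not sufficient).
  ToyQuantizedType flaggedSigned{true, Signedness::Signed};
  assert(currentCheck(flaggedSigned) && !proposedCheck(flaggedSigned));

  // Flag clear, storage is signless: the storage is what StableHLO wants,
  // yet the flag check rejects it (flag is not necessary).
  ToyQuantizedType unflaggedSignless{false, Signedness::Signless};
  assert(!currentCheck(unflaggedSignless) && proposedCheck(unflaggedSignless));
}
```

The two cases in `runPredicateExamples()` are the ones where the predicates disagree; when the producer follows the signless-means-signed convention, both checks agree.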
Any status update here? Reassigning to Sandeep who will have a better idea on how to triage (or perhaps pytorch/xla#6613 should mark this as fixed on submit)? |
Per the suggestion from @lsy323, we will review and land pytorch/xla#6613 once the CI is clean. |
Thanks @sdasgup3! Issue resolved in pytorch/xla#6613 |
What happened?
The MLIR module containing `stablehlo.uniform_quantize`/`dequantize` ops failed during bytecode serialization with an error. However, the MLIR module can be serialized to a readable format.
Steps to reproduce your issue
The StableHLO with uniform_quant/dequant ops is generated from PyTorch -> PyTorch/XLA -> StableHLO. Reproducing the bug e2e requires changes in PyTorch/XLA and the HLO->StableHLO converter. The change hasn't been merged to PyTorch/XLA head (but will be merged soon as an experimental feature). Please let me know if an e2e repro is needed.
This function is used to serialize stablehlo bytecode in PyTorch/XLA
Version information
StableHLO commit 46a2506
from openxla/xla: 51b59cfb1999c6f1b3ec59851675044b2c502aae