
Fix output type of custom calls while lowering quant/dequant torch op to HLO #6283

Merged 1 commit into master from sdasgup3/fix-output-type-of-quantize-op on Jan 11, 2024

Conversation

@sdasgup3 (Collaborator) commented on Jan 10, 2024

#5763 allows lowering the torch quantize/dequantize operations to HLO custom calls. For example,
the following PyTorch code

import torch
import torch.ao.quantization.fx._decomposed  # registers the quantized_decomposed ops

x = torch.randn(2, 3, 4, 5)  # input shape matches the HLO dump below
x = torch.ops.quantized_decomposed.quantize_per_tensor(
    x, 0.4, 2, -128, 127, torch.int8)
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
    x, 0.4, 2, -128, 127, torch.int8)

is lowered to the following HLO operations:

ENTRY %IrToHlo.5 (p0.1: f32[2,3,4,5]) -> (f32[2,3,4,5]) {
  %p0.1 = f32[2,3,4,5]{3,2,1,0} parameter(0)

  %custom-call.2 = f32[2,3,4,5]{3,2,1,0} custom-call(f32[2,3,4,5]{3,2,1,0} %p0.1), custom_call_target="stablehlo.uniform_quantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[0.4],zero_point=[2],storage_type=si8,expressed_type=f32,storage_min=-128,storage_max=127}

  %custom-call.3 = f32[2,3,4,5]{3,2,1,0} custom-call(f32[2,3,4,5]{3,2,1,0} %custom-call.2), custom_call_target="stablehlo.uniform_dequantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[0.4],zero_point=[2],storage_type=si8,expressed_type=f32,storage_min=-128,storage_max=127}
  ROOT %tuple.4 = (f32[2,3,4,5]{3,2,1,0}) tuple(f32[2,3,4,5]{3,2,1,0} %custom-call.3)
}

Note that the output of the custom call corresponding to the quantize op has element type f32. This is a logical problem, since the result of quantization is generally expected to be in the integer domain. Also note that the choice of output type should not affect the eventual conversion of the HLO custom calls to mhlo uniform.quantize/uniform.dequantize operations.
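With the fix, the quantize custom call carries the integer storage type instead. A sketch of the expected lowering for the same program, with the output type of the quantize custom call inferred from storage_type=si8 in the backend_config (value numbers and elided attributes are illustrative):

  %custom-call.2 = s8[2,3,4,5]{3,2,1,0} custom-call(f32[2,3,4,5]{3,2,1,0} %p0.1), custom_call_target="stablehlo.uniform_quantize", ...
  %custom-call.3 = f32[2,3,4,5]{3,2,1,0} custom-call(s8[2,3,4,5]{3,2,1,0} %custom-call.2), custom_call_target="stablehlo.uniform_dequantize", ...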

Moreover, based on https://github.com/pytorch/pytorch/blob/0b72ce1bd1a4a0596dde4053899b8a9a7999bc47/torch/ao/quantization/fx/_decomposed.py#L164, we set the output element type of the dequantize operation to f32.

Finally, this change improves the debuggability of the map queries by adding proper error messages.
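For illustration, a minimal sketch of what such a checked map query might look like; the function name and message are hypothetical, not the PR's actual code (XLA_CHECK is the check macro used throughout the torch_xla codebase):

// Hypothetical helper: looks up a torch dtype in the conversion map and
// fails with a descriptive message instead of an opaque std::out_of_range.
xla::PrimitiveType LookupHloDtype(const std::string& torch_dtype) {
  const auto& m = GetTorchIntDtypeToHloDtypeMap();
  auto it = m.find(torch_dtype);
  XLA_CHECK(it != m.end()) << "Unsupported torch dtype in conversion to HLO: "
                           << torch_dtype;
  return it->second;
}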

cc @GleasonK

@sdasgup3 requested a review from lsy323 on January 10, 2024 01:30
torch_xla/csrc/ops/dequant_tensor.cpp (review thread resolved)
@@ -192,4 +192,16 @@ GetHloDtypeToStablehloDtypeMap() {
return m_;
}

const std::unordered_map<std::string, xla::PrimitiveType>&
GetTorchIntDtypeToHloDtypeMap() {
static const std::unordered_map<std::string, xla::PrimitiveType> m_{
Collaborator
We could implement this as a switch / series of if statements to avoid a static dictionary; it's a short enough list. Not sure of PyTorch/XLA's stance on static data.

Collaborator Author
I agree about having less static data floating around. Replaced the dictionary with conditionals.
@lsy323 If the change looks good to you, then I can merge once the CI is green.
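For reference, a sketch of the conditional form discussed above; the dtype keys and exact signature are assumptions, not copied from the merged change:

// Maps a torch integer dtype name to the corresponding HLO primitive type
// using plain conditionals instead of a static dictionary.
xla::PrimitiveType GetTorchIntDtypeToHloDtype(const std::string& dtype) {
  if (dtype == "int8") return xla::PrimitiveType::S8;
  if (dtype == "uint8") return xla::PrimitiveType::U8;
  if (dtype == "int16") return xla::PrimitiveType::S16;
  if (dtype == "int32") return xla::PrimitiveType::S32;
  XLA_CHECK(false) << "Unsupported torch int dtype: " << dtype;
  return xla::PrimitiveType::PRIMITIVE_TYPE_INVALID;
}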

@sdasgup3 force-pushed the sdasgup3/fix-output-type-of-quantize-op branch 2 times, most recently from 8ad8b29 to f970401 on January 10, 2024 19:16
@sdasgup3 force-pushed the sdasgup3/fix-output-type-of-quantize-op branch from f970401 to 9da267d on January 10, 2024 21:45
@lsy323 merged commit 68f4750 into master on Jan 11, 2024
19 checks passed