
stablehlo.uniform_quantize cannot be serialized to bytecode #1812

Closed
lsy323 opened this issue Oct 18, 2023 · 8 comments

lsy323 commented Oct 18, 2023

What happened?

An MLIR module containing stablehlo.uniform_quantize/dequantize ops fails during bytecode serialization with the following error:

loc("custom-call.6"): error: failed to legalize operation 'stablehlo.uniform_quantize' that was explicitly marked illegal

However, the same module can be serialized to the human-readable text format without error:

module @IrToHlo.18 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<10x3x3x3xf32>, %arg2: tensor<1x3x10x10xf32>) -> tensor<1x10x8x8xf32> {
    %0 = stablehlo.constant dense<0.000000e+00> : tensor<1x10x8x8xf32>
    %1 = stablehlo.uniform_quantize %arg2 : (tensor<1x3x10x10xf32>) -> tensor<1x3x10x10x!quant.uniform<i8:f32, 1.000000e+00>>
    %2 = stablehlo.uniform_dequantize %1 : (tensor<1x3x10x10x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<1x3x10x10xf32>
    %3 = stablehlo.uniform_quantize %arg1 : (tensor<10x3x3x3xf32>) -> tensor<10x3x3x3x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>
    %4 = stablehlo.uniform_dequantize %3 : (tensor<10x3x3x3x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<10x3x3x3xf32>
    %5 = stablehlo.convolution(%2, %4) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [0, 0]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<1x3x10x10xf32>, tensor<10x3x3x3xf32>) -> tensor<1x10x8x8xf32>
    %6 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<10xf32>) -> tensor<1x10x8x8xf32>
    %7 = stablehlo.add %5, %6 : tensor<1x10x8x8xf32>
    %8 = stablehlo.maximum %7, %0 : tensor<1x10x8x8xf32>
    %9 = stablehlo.uniform_quantize %8 : (tensor<1x10x8x8xf32>) -> tensor<1x10x8x8x!quant.uniform<i8:f32, 1.000000e+00>>
    %10 = stablehlo.uniform_dequantize %9 : (tensor<1x10x8x8x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<1x10x8x8xf32>
    return %10 : tensor<1x10x8x8xf32>
  }
}

Steps to reproduce your issue

The StableHLO with uniform_quantize/dequantize ops is generated via PyTorch -> PyTorch/XLA -> StableHLO. Reproducing the bug end-to-end requires changes in PyTorch/XLA and in the HLO->StableHLO converter; those changes haven't been merged to PyTorch/XLA head yet (but will land soon as an experimental feature). Please let me know if an end-to-end repro is needed.

This function is used to serialize StableHLO bytecode in PyTorch/XLA.

Version information

StableHLO commit 46a2506

from openxla/xla: 51b59cfb1999c6f1b3ec59851675044b2c502aae


lsy323 commented Oct 18, 2023

cc @GleasonK

GleasonK commented

Thanks for reporting! Will take a look early next week.


arfaian commented Feb 20, 2024

Hey @GleasonK, any update on this issue? We're currently using a workaround for this that we'd like to get rid of.


sdasgup3 commented Feb 26, 2024

The issue is related to how signedness information for the integer storage type is propagated during the PyTorch to StableHLO export.

StableHLO, as bootstrapped from MHLO, inherits signless integers, which were added in MHLO for legacy reasons and are treated as signed integers.

During the PyTorch --> HLO export,
we create a signed integer (e.g. si8) as the storage type (from here and here) and store it in the attribute string dictionary.

During the HLO --> StableHLO conversion,
we create a signed integer as the storage type of UniformQuantizedType (here).

During StableHLO -> VHLO, we do not convert such signed types (cs), following the rationale that only signless integers are treated as signed integers unless #22 is resolved.

Note that this issue does not reproduce when we serialize (i.e. generate the bytecode) from a string representation. That is mainly because si8 in the in-memory StableHLO module representation is pretty-printed as i8, which can be serialized without any issue.

My proposal is for the PyTorch -> StableHLO export to generate signless integers for the storage type. Here is a prototype of the proposal: pytorch/xla#6613

Original HLO

The thing to note here is the si8 storage type, which gets normalized to i8 when the StableHLO is printed (see the next snippet).

HloModule IrToHlo.18, entry_computation_layout={(f32[10]{0}, f32[10,5]{1,0}, f32[3,5]{1,0})->(f32[3,10]{1,0})}

ENTRY %IrToHlo.18 (p0.1: f32[10], p1.2: f32[10,5], p2.6: f32[3,5]) -> (f32[3,10]) {
  %p2.6 = f32[3,5]{1,0} parameter(2)
  %custom-call.7 = s8[3,5]{1,0} custom-call(f32[3,5]{1,0} %p2.6), custom_call_target="stablehlo.uniform_quantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[1.00],zero_point=[0],storage_type=si8,expressed_type=f32,storage_min=-128,storage_max=127}
  %custom-call.8 = f32[3,5]{1,0} custom-call(s8[3,5]{1,0} %custom-call.7), custom_call_target="stablehlo.uniform_dequantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[1.00],zero_point=[0],storage_type=si8,expressed_type=f32,storage_min=-128,storage_max=127}
  %p1.2 = f32[10,5]{1,0} parameter(1)
  %custom-call.3 = s8[10,5]{1,0} custom-call(f32[10,5]{1,0} %p1.2), custom_call_target="stablehlo.uniform_quantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[1.00],zero_point=[0],storage_type=si8,expressed_type=f32,storage_min=-127,storage_max=127}
  %custom-call.4 = f32[10,5]{1,0} custom-call(s8[10,5]{1,0} %custom-call.3), custom_call_target="stablehlo.uniform_dequantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[1.00],zero_point=[0],storage_type=si8,expressed_type=f32,storage_min=-127,storage_max=127}
  %transpose.5 = f32[5,10]{0,1} transpose(f32[10,5]{1,0} %custom-call.4), dimensions={1,0}
  %dot.9 = f32[3,10]{1,0} dot(f32[3,5]{1,0} %custom-call.8, f32[5,10]{0,1} %transpose.5), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  %p0.1 = f32[10]{0} parameter(0)
  %reshape.10 = f32[1,10]{1,0} reshape(f32[10]{0} %p0.1)
  %broadcast.11 = f32[1,10]{1,0} broadcast(f32[1,10]{1,0} %reshape.10), dimensions={0,1}
  %reshape.12 = f32[10]{0} reshape(f32[1,10]{1,0} %broadcast.11)
  %broadcast.13 = f32[3,10]{1,0} broadcast(f32[10]{0} %reshape.12), dimensions={1}
  %add.14 = f32[3,10]{1,0} add(f32[3,10]{1,0} %dot.9, f32[3,10]{1,0} %broadcast.13)
  %custom-call.15 = s8[3,10]{1,0} custom-call(f32[3,10]{1,0} %add.14), custom_call_target="stablehlo.uniform_quantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[1.00],zero_point=[0],storage_type=si8,expressed_type=f32,storage_min=-128,storage_max=127}
  %custom-call.16 = f32[3,10]{1,0} custom-call(s8[3,10]{1,0} %custom-call.15), custom_call_target="stablehlo.uniform_dequantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[1.00],zero_point=[0],storage_type=si8,expressed_type=f32,storage_min=-128,storage_max=127}
  ROOT %tuple.17 = (f32[3,10]{1,0}) tuple(f32[3,10]{1,0} %custom-call.16)
}

Pretty-printed StableHLO

module @IrToHlo.18 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<10x5xf32>, %arg2: tensor<3x5xf32>) -> tensor<3x10xf32> {
    %0 = stablehlo.uniform_quantize %arg2 : (tensor<3x5xf32>) -> tensor<3x5x!quant.uniform<i8:f32, 1.000000e+00>>
    %1 = stablehlo.uniform_dequantize %0 : (tensor<3x5x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<3x5xf32>
    %2 = stablehlo.uniform_quantize %arg1 : (tensor<10x5xf32>) -> tensor<10x5x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>
    %3 = stablehlo.uniform_dequantize %2 : (tensor<10x5x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<10x5xf32>
    %4 = stablehlo.transpose %3, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[5,10]{0,1}"} : (tensor<10x5xf32>) -> tensor<5x10xf32>
    %5 = stablehlo.dot %1, %4, precision = [DEFAULT, DEFAULT] : (tensor<3x5xf32>, tensor<5x10xf32>) -> tensor<3x10xf32>
    %6 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<10xf32>) -> tensor<3x10xf32>
    %7 = stablehlo.add %5, %6 : tensor<3x10xf32>
    %8 = stablehlo.uniform_quantize %7 : (tensor<3x10xf32>) -> tensor<3x10x!quant.uniform<i8:f32, 1.000000e+00>>
    %9 = stablehlo.uniform_dequantize %8 : (tensor<3x10x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<3x10xf32>
    return %9 : tensor<3x10xf32>
  }
}


sdasgup3 commented

@lsy323 Can you please check if the proposal pytorch/xla#6613 works with your setup before I send it out for PR review?

GleasonK pushed a commit that referenced this issue Feb 28, 2024
The current StableHLO quantized signed-integer constraint
[cs](https://github.com/openxla/stablehlo/blob/4a26ddee5fdbe178a84f219513bfc48f565919b7/stablehlo/dialect/Base.td#L60)
uses `mlir::quant::QuantizedType::isSigned` to decide the signedness
of the storage type.

StableHLO, as bootstrapped from MHLO, inherits `signless` integers
(added in MHLO for legacy reasons) that are treated as signed integers.
This is not captured by the check `mlir::quant::QuantizedType::isSigned`
for the following reason.

 
Based on
[this](https://github.com/llvm/llvm-project/blob/16e74fd48988ac95551d0f64e1b36f78a82a89a2/mlir/include/mlir/Dialect/Quant/QuantTypes.h#L102),
`mlir::quant::QuantizedType::isSigned()` is true if the associated bit
in the
[flag](https://github.com/llvm/llvm-project/blob/16e74fd48988ac95551d0f64e1b36f78a82a89a2/mlir/lib/Dialect/Quant/IR/TypeDetail.h#L30)
is 1. Here are a few ways to set that flag to true.

```
auto is_signed = storage_type.cast<mlir::IntegerType>().isSignless();
... mlir::quant::UniformQuantizedPerAxisType::get(
      is_signed, storage_type, expressed_type, scales, zero_points,
      quantization_dimension, storage_min, storage_max);

or

auto is_signed = storage_type.cast<mlir::IntegerType>().isSigned();
... mlir::quant::UniformQuantizedPerAxisType::get(
      is_signed, storage_type, expressed_type, scales, zero_points,
      quantization_dimension, storage_min, storage_max);
```

In other words, it is on the producers of `mlir::quant::QuantizedType`
to set that bit based on **their interpretation of signedness**, which
in the case of StableHLO is signless. That means a `true` value of
`mlir::quant::QuantizedType::isSigned()` is not enough to ensure the
desired signedness of the storage type.
 
## Suggested change

Replace

```
CPred<"$_self.cast<mlir::quant::UniformQuantizedType>()" #
                 ".isSigned()">]>,
```

with 

```
           CPred<"$_self.cast<mlir::quant::UniformQuantizedType>()" #
                 ".getStorageType().cast<mlir::IntegerType>().isSignless()">]>,

```

IMO, we can skip the check
`mlir::quant::UniformQuantizedType::isSigned()` altogether, as it is
neither sufficient nor necessary.

Context:
#1812 (comment)
GleasonK assigned sdasgup3 and unassigned GleasonK Mar 4, 2024

GleasonK commented Mar 4, 2024

Any status update here? Reassigning to Sandeep, who will have a better idea of how to triage (or perhaps pytorch/xla#6613 should mark this as fixed on submit).


sdasgup3 commented Mar 5, 2024

Per the suggestion from @lsy323, we will review and land pytorch/xla#6613 once the CI is clean.


lsy323 commented Mar 5, 2024

Thanks @sdasgup3! Issue resolved in pytorch/xla#6613.
