Skip to content

Commit

Permalink
Introduce QuantizedType (#1352)
Browse files Browse the repository at this point in the history
StableHLO dialect currently supports quantization via:
  1) Supporting `quant.uniform` element types.
  2) Having dedicated ops like `uniform_quantize` / `uniform_dequantize`.
  3) Allowing regular ops like `add` / `convolution` to take quantized
tensors.

This support was inherited from MHLO when StableHLO was bootstrapped,
and MHLO support was motivated by mobile use cases and inherited from
TFLite.

As pointed out in #1149, StableHLO specification doesn't support
quantization at the moment, and this is an important gap that we would 
like to fix before StableHLO v1.0 (see #588).

To continue the discussion started in #1149 and to make progress towards
v1.0, this pull request:
  A) Adds QuantizedType to the StableHLO specification, modelled after
[TFLite quantization
spec](https://www.tensorflow.org/lite/performance/quantization_spec).
  B) To start a conversation about the applications of QuantizedType and
the semantics of quantized ops, proposes semantics for quantized `add`.

TFLite quantization spec doesn't cover everything. It specs constraints
on types (which we captured accordingly in this pull request), but it
doesn't go into describing semantics of quantized ops.

As a result, the proposed semantics for quantized `add` is intentionally
naive, as compared with the much more involved implementations in the
TensorFlow repository, e.g.:
  *
[tfl.add](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/add.cc).
  *
[tf.UniformQuantizedAdd](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_add_op.cc).

upd: After community discussion, we removed the spec for quantized
`add` leaving that for future work, since further alignment is required.

---------

Co-authored-by: Eugene Burmako <burmako@google.com>
  • Loading branch information
sdasgup3 and Eugene Burmako authored Apr 14, 2023
1 parent c0769c2 commit e83c5e0
Showing 1 changed file with 64 additions and 5 deletions.
69 changes: 64 additions & 5 deletions docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ completely numeric to simplify generation of StableHLO programs.

```ebnf
Type ::= ValueType | NonValueType
ValueType ::= TensorType | TokenType | TupleType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= ElementType | FunctionType | StringType
```

Expand Down Expand Up @@ -116,6 +116,69 @@ types, for example, to include layouts
([#629](https://github.com/openxla/stablehlo/issues/629)) and sparsity
([#1078](https://github.com/openxla/stablehlo/issues/1078)).

```ebnf
QuantizedTensorType ::= 'tensor' '<' TensorShape QuantizedElementType '>'
QuantizedElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
```

**Quantized element types** represent integer values of a **storage type** in
the range from `storage_min` to `storage_max` (inclusive) that correspond to
floating-point values of an **expressed type**. For a given integer value `i`,
the corresponding floating-point value `f` can be computed as
`f = (i - zero_point) * scale`, where `scale` and `zero_point` are called
**quantization parameters**. The `storage_min` and `storage_max` are optional
in the grammar, but have default values of `min_value(storage_type)` and
`max_value(storage_type)` respectively. Quantized element types have the
following constraints:

* (C1) `num_bits(storage_type) < num_bits(expressed_type)`.
* (C2) `type(storage_min) = storage_type`.
* (C3) `type(storage_max) = storage_type`.
* (C4) `min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)`.
* (C5) For all `i`, `type(scales[i]) = expressed_type`.
* (C6) For all `i`, `scales[i] > 0`.
* (C7) For all `i`, `is_finite(scales[i])`.
* (C8) For all `i`, `storage_min <= zero_points[i] <= storage_max`.
* (C9) For all `i`, `type(zero_points[i]) = storage_type`.
* (C10) `size(scales) = size(zero_points)`.
* (C11) If `quantization_dimension` is empty, then `size(scales) = 1`.
* (C12) If `quantization_dimension` is not empty, then
`0 <= quantization_dimension`.

**Quantized tensor types** represent tensors with quantized elements. These
tensors are exactly the same as regular tensors, except that their elements
have quantized element types, instead of regular element types.

In quantized tensors, quantization can be **per-tensor**, meaning, having
one `scale` and `zero_point` for the entire tensor or can be **per-axis**,
meaning, having multiple `scales` and `zero_points`, one pair per slice of
a particular dimension `quantized_dimension`. More formally, in a tensor `t` of
with per-axis quantization, there are `dim(t, quantized_dimension)` slices
of the `quantized_dimension`: `t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]`, etc.
All elements in the `i`th slice use `scales[i]` and `zero_points[i]` as their
quantization parameters. Quantized tensor types have the following constraints:

* For per-tensor quantization:
* No additional constraints.
* For per-axis quantization:
* (C12) `quantization_dimension < size(shape)`.
* (C13) `size(scales) = shape[quantization_dimension]`.

```ebnf
TokenType ::= 'token'
```
Expand Down Expand Up @@ -173,10 +236,6 @@ values of type `tensor<T>`).
and an **imaginary part** of the same **element type**. Supported complex
types are `complex<f32>` (both parts are of type `f32`) and `complex<f64>`
(both parts are of type `f64`).
* In the future, we are also planning to introduce **quantized types** that
represent integer values obtained via uniform quantization of floating-point
values using given scales and zero points
([#588](https://github.com/openxla/stablehlo/issues/588)).

```ebnf
FunctionType ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
Expand Down

0 comments on commit e83c5e0

Please sign in to comment.