Skip to content

Commit

Permalink
Quantization Verifiers based on T2x set of Traits (#2041)
Browse files Browse the repository at this point in the history
* made `isCompatibleElementTypeForHloTypeInference` stricter to return an
error for the {non-quantized, quantized} and {per-axis quantized,
per-tensor quantized} cases
* `AddOp` VHLO test failures: addressed test failures because the
{non-quantized, quantized} combination is no longer allowed
* Corrected traits for `CholeskyOp` and `ClampOp` to match them with the
spec

Note: This PR is based on the in-review PR
#2007

A follow-up PR will add/update op verifiers for ops that need special
handling
  • Loading branch information
abhigunj authored Mar 5, 2024
1 parent c5292f6 commit da04b39
Show file tree
Hide file tree
Showing 27 changed files with 100 additions and 78 deletions.
35 changes: 20 additions & 15 deletions stablehlo/dialect/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,6 @@ limitations under the License.
namespace mlir {
namespace hlo {

namespace {
// Returns the expressed type of a quantized type; for any non-quantized
// type, returns the type unchanged.
Type getExpressedTypeOrSelf(Type type) {
  if (auto quantType = type.dyn_cast<quant::QuantizedType>())
    return quantType.getExpressedType();
  return type;
}
}  // namespace

LogicalResult verifyCompatibleShapeWithBounds(Type type1, Type type2) {
if (failed(verifyCompatibleShape(type1, type2))) return failure();

Expand Down Expand Up @@ -72,20 +65,32 @@ bool isCompatibleElementTypeForHloTypeInference(Type tp1, Type tp2) {
tp1 = getElementTypeOrSelf(tp1);
tp2 = getElementTypeOrSelf(tp2);

// Quantization: In the most general case, we allow any combination of
// quantized/non-quantized across any combination of operands/results,
// and some differences in quantization parameters across operands/results.
// Individual ops may introduce additional constraints.
// For quantized types:
// a. both `tp1` and `tp2` should be quantized types
// b. with similar quantization granularity (i.e. both per-tensor or both
// per-axis)
// c. with equal storage_type, storage_type_min, storage_type_max, and
// expressed_type
auto qtp1 = tp1.dyn_cast<quant::QuantizedType>();
auto qtp2 = tp2.dyn_cast<quant::QuantizedType>();
if (qtp1 && qtp2) {
if (qtp1.getStorageType() != qtp2.getStorageType() ||
qtp1.getStorageTypeMin() != qtp2.getStorageTypeMin() ||
qtp1.getStorageTypeMax() != qtp2.getStorageTypeMax())
qtp1.getStorageTypeMax() != qtp2.getStorageTypeMax() ||
qtp1.getExpressedType() != qtp2.getExpressedType()) {
return false;
}

auto qpatp1 = qtp1.dyn_cast<quant::UniformQuantizedPerAxisType>();
auto qpatp2 = qtp2.dyn_cast<quant::UniformQuantizedPerAxisType>();
bool quantizationGranularityMatches =
(qpatp1 && qpatp2) || (!qpatp1 && !qpatp2);

return quantizationGranularityMatches;
}
auto etp1 = getExpressedTypeOrSelf(tp1);
auto etp2 = getExpressedTypeOrSelf(tp2);

// return false if only one is of quantized type
if (qtp1 || qtp2) return false;

// Sparsity: In the most general case, we allow any combination of
// sparsity/denseness across any combination of operands/results, as well as
Expand All @@ -96,7 +101,7 @@ bool isCompatibleElementTypeForHloTypeInference(Type tp1, Type tp2) {

// Default case: Unless dynamism, quantization and/or sparsity are involved,
// the types are required to be exactly equal.
return etp1 == etp2;
return tp1 == tp2;
}

bool isCompatibleForHloTypeInference(Type tp1, Type tp2) {
Expand Down
23 changes: 23 additions & 0 deletions stablehlo/dialect/Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,29 @@ class CompatibleOperandsAndResultElementType
}
};

template <typename ConcreteType>
class CompatibleOperandsElementType
    : public mlir::OpTrait::TraitBase<ConcreteType,
                                      CompatibleOperandsElementType> {
 public:
  // Verifies that every operand's element type is compatible (in the sense
  // of isCompatibleElementTypeForHloTypeInference) with the first operand's
  // element type. Requires the op to have at least one operand.
  static LogicalResult verifyTrait(Operation *op) {
    if (failed(mlir::OpTrait::impl::verifyAtLeastNOperands(op, 1)))
      return failure();

    // Operand 0 serves as the reference type; each remaining operand is
    // checked against it.
    Type reference = op->getOperand(0).getType();
    for (Type operandType : op->getOperandTypes()) {
      if (!isCompatibleElementTypeForHloTypeInference(operandType, reference))
        return op->emitOpError(
            "requires compatible element types for all operands");
    }

    return success();
  }
};

template <typename ConcreteType>
class CompatibleOperandsAndResultType
: public mlir::OpTrait::TraitBase<ConcreteType,
Expand Down
3 changes: 3 additions & 0 deletions stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,9 @@ def HLO_CompatibleOperandsAndResultType : TraitList<
def HLO_CompatibleOperandsAndResultElementType :
HLO_NativeOpTrait<"CompatibleOperandsAndResultElementType">;

// Native trait verifying that all operands of an op have mutually
// compatible element types (see CompatibleOperandsElementType in Base.h).
def HLO_CompatibleOperandsElementType :
HLO_NativeOpTrait<"CompatibleOperandsElementType">;

def HLO_BoundedAttrInterface : AttrInterface<"BoundedAttrInterface"> {
let cppNamespace = "::mlir::hlo";

Expand Down
17 changes: 7 additions & 10 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1532,7 +1532,7 @@ def StableHLO_TupleOp : StableHLO_Op<"tuple", [Pure,
}

def StableHLO_CompareOp: StableHLO_Op<"compare", [Pure, Elementwise,
SameOperandsElementType /*compare_c1*/,
HLO_CompatibleOperandsElementType /*compare_c1*/,
SameOperandsAndResultShape /*compare_c2*/,
InferTensorTypeWithReify /*compare_c1, compare_c2*/]> {
let summary = "Compare operation";
Expand Down Expand Up @@ -1685,8 +1685,7 @@ def StableHLO_DynamicUpdateSliceOp: StableHLO_Op<"dynamic_update_slice",
//===----------------------------------------------------------------------===//

def StableHLO_BatchNormGradOp : StableHLO_Op<"batch_norm_grad", [Pure,
AllElementTypesMatch<["operand", "scale", "mean", "variance", "grad_output",
"grad_operand", "grad_scale", "grad_offset"] /*batch_norm_grad_c2*/>,
HLO_CompatibleOperandsAndResultElementType /*batch_norm_grad_c2*/,
InferTensorType /*batch_norm_grad_c3, batch_norm_grad_c4*/]> {
let summary = "BatchNormGrad operation";
let description = [{
Expand Down Expand Up @@ -1725,8 +1724,7 @@ def StableHLO_BatchNormGradOp : StableHLO_Op<"batch_norm_grad", [Pure,
}

def StableHLO_BatchNormInferenceOp : StableHLO_Op<"batch_norm_inference",
[Pure, AllElementTypesMatch<["operand", "scale", "offset", "mean",
"variance", "result"]> /*batch_norm_inference_c2*/,
[Pure, HLO_CompatibleOperandsAndResultElementType /*batch_norm_inference_c2*/,
InferTensorType /*batch_norm_inference_c7*/]> {
let summary = "BatchNormInference operation";
let description = [{
Expand Down Expand Up @@ -1759,8 +1757,7 @@ def StableHLO_BatchNormInferenceOp : StableHLO_Op<"batch_norm_inference",
}

def StableHLO_BatchNormTrainingOp : StableHLO_Op<"batch_norm_training",
[Pure, AllElementTypesMatch<["operand", "scale", "offset", "output",
"batch_mean", "batch_var"]> /*batch_norm_training_c2*/,
[Pure, HLO_CompatibleOperandsAndResultElementType /*batch_norm_training_c2*/,
InferTensorType /*batch_norm_training_c5, batch_norm_training_c6, batch_norm_training_c7*/]> {
let summary = "BatchNormTraining operation";
let description = [{
Expand Down Expand Up @@ -1927,7 +1924,7 @@ def StableHLO_DynamicBroadcastInDimOp : StableHLO_ShapedInterfaceOp<
// directly.

def StableHLO_CholeskyOp : StableHLO_Op<"cholesky",
[Pure, SameOperandsAndResultElementType /*cholesky_c1*/,
[Pure, HLO_CompatibleOperandsAndResultElementType /*cholesky_c1*/,
InferTensorType /*cholesky_c1*/]> {
let summary = "Cholesky operation";
let description = [{
Expand All @@ -1954,7 +1951,7 @@ def StableHLO_CholeskyOp : StableHLO_Op<"cholesky",
}

def StableHLO_ClampOp : StableHLO_ShapedInterfaceOp<"clamp", [Pure,
SameOperandsAndResultElementType /* clamp_c3 */, HLO_BroadcastingElementwise,
HLO_CompatibleOperandsAndResultElementType /* clamp_c3 */, HLO_BroadcastingElementwise,
InferTensorType]> {
let summary = "Clamp operation";
let description = [{
Expand Down Expand Up @@ -2814,7 +2811,7 @@ def StableHLO_TransposeOp: StableHLO_ShapedInterfaceOp<"transpose",
}

def StableHLO_TriangularSolveOp: StableHLO_Op<"triangular_solve",
[Pure, SameOperandsAndResultElementType, InferTensorType]> {
[Pure, HLO_CompatibleOperandsAndResultElementType, InferTensorType]> {
let summary = "TriangularSolve operation";
let description = [{
Solves batches of systems of linear equations with lower or upper triangular
Expand Down
10 changes: 2 additions & 8 deletions stablehlo/tests/ops_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5070,16 +5070,10 @@ func.func @is_compatible_dynamism_dim_mismatch(%arg0: tensor<1x?xf32>) {

// -----

// TODO(b/230263270): For stablehlo.add, the plan is to only allow fp+fp=fp, q+q=q and q+q=fp.
func.func @is_compatible_quant_mix_non_quant(%arg0: tensor<1xf32>, %arg1: tensor<1x!quant.uniform<i8:f32, 1.0:17>>) {
%0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%1 = "stablehlo.add"(%arg0, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1x!quant.uniform<i8:f32, 1.0:17>>
%2 = "stablehlo.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 1.0:17>>
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 1.0:17>>
%4 = "stablehlo.add"(%arg1, %arg0) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1xf32>) -> tensor<1xf32>
%5 = "stablehlo.add"(%arg1, %arg0) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1xf32>) -> tensor<1xf32>
%6 = "stablehlo.add"(%arg1, %arg1) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 1.0:17>>
%7 = "stablehlo.add"(%arg1, %arg1) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 1.0:17>>
%1 = "stablehlo.add"(%arg1, %arg1) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 1.0:17>>
%2 = "stablehlo.add"(%arg1, %arg1) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 1.0:17>>
func.return
}

Expand Down
2 changes: 1 addition & 1 deletion stablehlo/tests/ops_stablehlo_quantized.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func.func @ops_per_tensor_quantization(
%sqrt = "stablehlo.sqrt"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%subtract = "stablehlo.subtract"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%tanh = "stablehlo.tanh"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%transpose = "stablehlo.transpose"(%arg0) {permutation = array<i64: 0, 2, 1>}: (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:1, {0.1:-30, 0.5:-20}>>
%transpose = "stablehlo.transpose"(%arg0) {permutation = array<i64: 0, 2, 1>}: (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8<-128:127>:f32, 1.0:17>>
%tuple = "stablehlo.tuple"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tuple<tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>>
%uniform_dequantize = "stablehlo.uniform_dequantize" (%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2xf32>
%uniform_quantize = "stablehlo.uniform_quantize" (%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
Expand Down
8 changes: 4 additions & 4 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2226,10 +2226,10 @@ func.func @type_dynamism_ranked(%arg0: tensor<?xf32>) -> tensor<?xf32> {
}

// CHECK-LABEL: "type_quantization"
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<f32>) -> tensor<f32>
func.return %0 : tensor<f32>
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>) -> !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>>
func.return %0 : tensor<!quant.uniform<i8:f32, 34.0:16>>
}

// CHECK: function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.token_v1) -> !vhlo.token_v1>>
Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir.bc
Binary file not shown.
8 changes: 4 additions & 4 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2233,10 +2233,10 @@ func.func @type_dynamism_ranked(%arg0: tensor<?xf32>) -> tensor<?xf32> {
}

// CHECK-LABEL: "type_quantization"
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<f32>) -> tensor<f32>
func.return %0 : tensor<f32>
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>) -> !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>>
func.return %0 : tensor<!quant.uniform<i8:f32, 34.0:16>>
}

// CHECK: function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.token_v1) -> !vhlo.token_v1>>
Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc
Binary file not shown.
8 changes: 4 additions & 4 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2233,10 +2233,10 @@ func.func @type_dynamism_ranked(%arg0: tensor<?xf32>) -> tensor<?xf32> {
}

// CHECK-LABEL: "type_quantization"
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<f32>) -> tensor<f32>
func.return %0 : tensor<f32>
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>) -> !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>>
func.return %0 : tensor<!quant.uniform<i8:f32, 34.0:16>>
}

// CHECK: function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.token_v1) -> !vhlo.token_v1>>
Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir.bc
Binary file not shown.
8 changes: 4 additions & 4 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2233,10 +2233,10 @@ func.func @type_dynamism_ranked(%arg0: tensor<?xf32>) -> tensor<?xf32> {
}

// CHECK-LABEL: "type_quantization"
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<f32>) -> tensor<f32>
func.return %0 : tensor<f32>
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>) -> !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>>
func.return %0 : tensor<!quant.uniform<i8:f32, 34.0:16>>
}

// CHECK: function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.token_v1) -> !vhlo.token_v1>>
Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir.bc
Binary file not shown.
8 changes: 4 additions & 4 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2233,10 +2233,10 @@ func.func @type_dynamism_ranked(%arg0: tensor<?xf32>) -> tensor<?xf32> {
}

// CHECK-LABEL: "type_quantization"
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<f32>) -> tensor<f32>
func.return %0 : tensor<f32>
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>) -> !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>>
func.return %0 : tensor<!quant.uniform<i8:f32, 34.0:16>>
}

// CHECK: function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.token_v1) -> !vhlo.token_v1>>
Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir.bc
Binary file not shown.
8 changes: 4 additions & 4 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2241,10 +2241,10 @@ func.func @type_dynamism_ranked(%arg0: tensor<?xf32>) -> tensor<?xf32> {
}

// CHECK-LABEL: "type_quantization"
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<f32>) -> tensor<f32>
func.return %0 : tensor<f32>
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>) -> !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>>
func.return %0 : tensor<!quant.uniform<i8:f32, 34.0:16>>
}

// CHECK: function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.token_v1) -> !vhlo.token_v1>>
Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir.bc
Binary file not shown.
8 changes: 4 additions & 4 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2252,10 +2252,10 @@ func.func @type_dynamism_ranked(%arg0: tensor<?xf32>) -> tensor<?xf32> {
}

// CHECK-LABEL: "type_quantization"
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<f32>) -> tensor<f32>
func.return %0 : tensor<f32>
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>) -> !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>>
func.return %0 : tensor<!quant.uniform<i8:f32, 34.0:16>>
}

// CHECK: function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.token_v1) -> !vhlo.token_v1>>
Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc
Binary file not shown.
Loading

0 comments on commit da04b39

Please sign in to comment.