Skip to content

Commit

Permalink
fix lit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com>
  • Loading branch information
AlexandreEichenberger committed Sep 25, 2024
1 parent 0dffbc9 commit 34b97da
Show file tree
Hide file tree
Showing 6 changed files with 452 additions and 505 deletions.
16 changes: 13 additions & 3 deletions src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,19 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
llvm_unreachable("unsupported input type");
IntegerType quantizedIntType = cast<IntegerType>(quantizedElementType);
bool isSigned = quantizedIntType.isSignless() || quantizedIntType.isSigned();
Type quantizedElementTypeInputSized =
rewriter.getIntegerType(inputWidth, isSigned);

Type quantizedElementTypeInputSized;
if (isSigned) {
// Cannot use getIntegerType(inputWidth, true) as it returns signed ints.
if (inputWidth == 64)
quantizedElementTypeInputSized = rewriter.getI64Type();
else if (inputWidth == 32)
quantizedElementTypeInputSized = rewriter.getI32Type();
else
llvm_unreachable("unsupported input type");
} else {
// unsigned of the right type
quantizedElementTypeInputSized = rewriter.getIntegerType(inputWidth, false);
}
create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel,
{flatInput}, {inputAF}, {flatAlloc}, {outputAF},
{[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,22 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor<?x2xf32>) -> (tensor<?x2xu
// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2
// CHECK-DAG: [[VAR_dim_9_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_2_]] : memref<?x2xf32>
// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_9_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 2){
// CHECK: [[VAR_32_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_32_]]#0, [[VAR_32_]]#1] : memref<?x2xf32>
// CHECK: [[VAR_31_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_31_]]#0, [[VAR_31_]]#1] : memref<?x2xf32>
// CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = krnl.load [[RES_3_]][] : memref<f32>
// CHECK: [[VAR_35_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32
// CHECK: krnl.store [[VAR_35_]], [[RES_3_]][] : memref<f32>
// CHECK: [[VAR_34_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32
// CHECK: krnl.store [[VAR_34_]], [[RES_3_]][] : memref<f32>
// CHECK: }
// CHECK: [[RES_4_:%.+]] = memref.alloc() : memref<f32>
// CHECK: krnl.memset [[RES_4_]], [[CST_0_]] : memref<f32>
// CHECK-DAG: [[LOOP_1_:%.+]]:2 = krnl.define_loops 2
// CHECK-DAG: [[VAR_dim_11_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_2_]] : memref<?x2xf32>
// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_2_:%.+]] = 0 to [[VAR_dim_11_]], [[LOOP_1_]]#1 -> [[I_3_:%.+]] = 0 to 2){
// CHECK: [[VAR_32_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_32_1_]]#0, [[VAR_32_1_]]#1] : memref<?x2xf32>
// CHECK: [[VAR_31_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_31_1_]]#0, [[VAR_31_1_]]#1] : memref<?x2xf32>
// CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = krnl.load [[RES_4_]][] : memref<f32>
// CHECK: [[VAR_35_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32
// CHECK: krnl.store [[VAR_35_1_]], [[RES_4_]][] : memref<f32>
// CHECK: [[VAR_34_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32
// CHECK: krnl.store [[VAR_34_1_]], [[RES_4_]][] : memref<f32>
// CHECK: }
// CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = krnl.load [[RES_3_]][] : memref<f32>
// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = krnl.load [[RES_4_]][] : memref<f32>
Expand Down Expand Up @@ -87,40 +87,34 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor<?x2xf32>) -> (tensor<?x2xu
// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
// CHECK: affine.store [[VAR_29_]], [[RES_6_]][0] : memref<1xindex>
// CHECK-DAG: [[VAR_reshape_14_:%.+]] = memref.reshape [[RES_]]([[RES_]]_13) : (memref<?x2xui8>, memref<1xindex>) -> memref<?xui8>
// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc([[VAR_28_]]) {{.*}}: memref<?xf32>
// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1
// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){
// CHECK: [[VAR_32_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index
// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_32_2_]]{{.}} : memref<?xf32>
// CHECK: [[VAR_31_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index
// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_31_2_]]{{.}} : memref<?xf32>
// CHECK: [[LOAD_RES_3_MEM_1_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_1_]], [[VAR_7_]] : f32
// CHECK: [[VAR_35_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32
// CHECK: [[VAR_36_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_35_2_]] : f32
// CHECK-DAG: [[VAR_37_:%.+]] = arith.cmpf ogt, [[VAR_36_]], [[CST_5_dot_000000_]] : f32
// CHECK-DAG: [[VAR_38_:%.+]] = arith.addf [[VAR_35_2_]], [[CST_1_dot_000000_]] : f32
// CHECK: [[VAR_34_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32
// CHECK: [[VAR_35_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_34_2_]] : f32
// CHECK-DAG: [[VAR_36_:%.+]] = arith.cmpf ogt, [[VAR_35_]], [[CST_5_dot_000000_]] : f32
// CHECK-DAG: [[VAR_37_:%.+]] = arith.addf [[VAR_34_2_]], [[CST_1_dot_000000_]] : f32
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_39_:%.+]] = arith.select [[VAR_37_]], [[VAR_38_]], [[VAR_35_2_]] : f32
// CHECK-DAG: [[VAR_40_:%.+]] = arith.mulf [[VAR_35_2_]], [[CST_5_dot_000000_]] : f32
// CHECK: [[VAR_41_:%.+]] = math.floor [[VAR_40_]] : f32
// CHECK: [[VAR_42_:%.+]] = arith.mulf [[VAR_41_]], [[CST_2_dot_000000_]] : f32
// CHECK: [[VAR_43_:%.+]] = arith.subf [[VAR_35_2_]], [[VAR_42_]] : f32
// CHECK-DAG: [[VAR_44_:%.+]] = arith.cmpf oeq, [[VAR_43_]], [[CST_1_dot_000000_]] : f32
// CHECK-DAG: [[VAR_45_:%.+]] = arith.addf [[VAR_35_2_]], [[CST_1_dot_000000_]] : f32
// CHECK-DAG: [[VAR_38_:%.+]] = arith.select [[VAR_36_]], [[VAR_37_]], [[VAR_34_2_]] : f32
// CHECK-DAG: [[VAR_39_:%.+]] = arith.mulf [[VAR_34_2_]], [[CST_5_dot_000000_]] : f32
// CHECK: [[VAR_40_:%.+]] = math.floor [[VAR_39_]] : f32
// CHECK: [[VAR_41_:%.+]] = arith.mulf [[VAR_40_]], [[CST_2_dot_000000_]] : f32
// CHECK: [[VAR_42_:%.+]] = arith.subf [[VAR_34_2_]], [[VAR_41_]] : f32
// CHECK-DAG: [[VAR_43_:%.+]] = arith.cmpf oeq, [[VAR_42_]], [[CST_1_dot_000000_]] : f32
// CHECK-DAG: [[VAR_44_:%.+]] = arith.addf [[VAR_34_2_]], [[CST_1_dot_000000_]] : f32
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_46_:%.+]] = arith.select [[VAR_44_]], [[VAR_45_]], [[VAR_35_2_]] : f32
// CHECK-DAG: [[VAR_47_:%.+]] = arith.cmpf oeq, [[VAR_36_]], [[CST_5_dot_000000_]] : f32
// CHECK: [[VAR_48_:%.+]] = arith.select [[VAR_47_]], [[VAR_46_]], [[VAR_39_]] : f32
// CHECK: [[VAR_49_:%.+]] = arith.addf [[VAR_48_]], [[VAR_25_]] : f32
// CHECK: [[VAR_50_:%.+]] = arith.maxnumf [[VAR_49_]], [[CST_0_dot_000000_]] : f32
// CHECK: [[VAR_51_:%.+]] = arith.minnumf [[VAR_50_]], [[CST_2_dot_550000_]] : f32
// CHECK: krnl.store [[VAR_51_]], [[RES_7_]]{{.}}[[VAR_32_2_]]{{.}} : memref<?xf32>
// CHECK: }
// CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1
// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){
// CHECK: [[VAR_32_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index
// CHECK: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_32_3_]]{{.}} : memref<?xf32>
// CHECK: [[LOAD_RES_3_MEM_1_1_:%.+]] = arith.fptoui [[LOAD_PARAM_0_MEM_1_1_]] : f32 to i8
// CHECK: [[VAR_35_3_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_RES_3_MEM_1_1_]] : i8 to ui8
// CHECK: krnl.store [[VAR_35_3_]], [[VAR_reshape_14_]]{{.}}[[VAR_32_3_]]{{.}} : memref<?xui8>
// CHECK-DAG: [[VAR_45_:%.+]] = arith.select [[VAR_43_]], [[VAR_44_]], [[VAR_34_2_]] : f32
// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_35_]], [[CST_5_dot_000000_]] : f32
// CHECK: [[VAR_47_:%.+]] = arith.select [[VAR_46_]], [[VAR_45_]], [[VAR_38_]] : f32
// CHECK: [[VAR_48_:%.+]] = arith.addf [[VAR_47_]], [[VAR_25_]] : f32
// CHECK: [[VAR_49_:%.+]] = arith.maxnumf [[VAR_48_]], [[CST_0_dot_000000_]] : f32
// CHECK: [[VAR_50_:%.+]] = arith.minnumf [[VAR_49_]], [[CST_2_dot_550000_]] : f32
// CHECK: [[VAR_51_:%.+]] = arith.fptoui [[VAR_50_]] : f32 to i32
// CHECK: [[VAR_52_:%.+]] = arith.trunci [[VAR_51_]] : i32 to i8
// CHECK: [[VAR_53_:%.+]] = builtin.unrealized_conversion_cast [[VAR_52_]] : i8 to ui8
// CHECK: krnl.store [[VAR_53_]], [[VAR_reshape_14_]]{{.}}[[VAR_31_2_]]{{.}} : memref<?xui8>
// CHECK: }
// CHECK: return [[RES_]], [[RES_]]_6, [[RES_]]_7 : memref<?x2xui8>, memref<f32>, memref<ui8>
// CHECK: }
Expand Down
Loading

0 comments on commit 34b97da

Please sign in to comment.