[mlir][Tosa] fix fp16/bf16 support for AvgPool2d (llvm#68718)
Currently, the AvgPool2d operation in the TOSA MLIR dialect does not
accept half-precision fp16 and bf16 tensors, contrary to what is stated
in the [TOSA
specification](https://www.mlplatform.org/tosa/tosa_spec.html#_avg_pool2d).
This issue was previously raised on GitHub in llvm#63424 and is due to
a bug in the AvgPool2d verifier.

This patch fixes the AvgPool2d verifier to accept the fp16 and bf16
datatypes for the input/output tensors and the accumulator, and adds
related LIT test cases in Tosa/ops.mlir.
fabrizio-indirli authored and ttjost committed Feb 20, 2024
1 parent 7093b19 commit b199e1d
Showing 2 changed files with 24 additions and 10 deletions.
20 changes: 10 additions & 10 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -184,20 +184,20 @@ LogicalResult tosa::AvgPool2dOp::verify() {
   if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
     return emitOpError("accumulator type for integer tensor is not i32");
 
-  if ((inputETy.isBF16() || inputETy.isF16()) &&
-      !(accType.isF16() || accType.isF32()))
-    return emitOpError("accumulator type for f16/bf16 tensor is not f16/f32");
+  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
+    return emitOpError("accumulator type for f16 tensor is not f16/f32");
+
+  if (inputETy.isBF16() && !accType.isF32())
+    return emitOpError("accumulator type for bf16 tensor is not f32");
 
   if (inputETy.isF32() && !accType.isF32())
     return emitOpError("accumulator type for f32 tensor is not f32");
 
-  if (inputETy.isF32() && resultETy.isF32())
-    return success();
-  if (inputETy.isBF16() && resultETy.isBF16())
-    return success();
-  if (inputETy.isInteger(8) && resultETy.isInteger(8))
-    return success();
-  if (inputETy.isInteger(16) && resultETy.isInteger(16))
+  if ((inputETy.isF32() && resultETy.isF32()) ||
+      (inputETy.isF16() && resultETy.isF16()) ||
+      (inputETy.isBF16() && resultETy.isBF16()) ||
+      (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
+      (inputETy.isInteger(16) && resultETy.isInteger(16)))
     return success();
 
   return emitOpError("input/output element types are incompatible.");
14 changes: 14 additions & 0 deletions mlir/test/Dialect/Tosa/ops.mlir
@@ -16,6 +16,20 @@ func.func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32
   return %0 : tensor<1x7x7x9xf32>
 }
 
+// -----
+// CHECK-LABEL: avg_pool2d_f16
+func.func @test_avg_pool2d_f16(%arg0: tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> {
+  %0 = tosa.avg_pool2d %arg0 {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16>
+  return %0 : tensor<1x7x7x9xf16>
+}
+
+// -----
+// CHECK-LABEL: avg_pool2d_f16_accumf32
+func.func @test_avg_pool2d_f16_accumf32(%arg0: tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> {
+  %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16>
+  return %0 : tensor<1x7x7x9xf16>
+}
+
 // -----
 // CHECK-LABEL: avg_pool2d_i8
 func.func @test_avg_pool2d_i8(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> {
