diff --git a/examples/DIPDialect/dip.mlir b/examples/DIPDialect/dip.mlir index 4293d0a59..b1b722054 100644 --- a/examples/DIPDialect/dip.mlir +++ b/examples/DIPDialect/dip.mlir @@ -1,12 +1,12 @@ func.func @corr_2d_constant_padding(%inputImage : memref, %kernel : memref, %outputImage : memref, %centerX : index, %centerY : index, %constantValue : f32) { - dip.corr_2d CONSTANT_PADDING %inputImage, %kernel, %outputImage, %centerX, %centerY, %constantValue : memref, memref, memref, index, index, f32 + dip.corr_2d %inputImage, %kernel, %outputImage, %centerX, %centerY, %constantValue : memref, memref, memref, index, index, f32 return } func.func @corr_2d_replicate_padding(%inputImage : memref, %kernel : memref, %outputImage : memref, %centerX : index, %centerY : index, %constantValue : f32) { - dip.corr_2d REPLICATE_PADDING %inputImage, %kernel, %outputImage, %centerX, %centerY , %constantValue : memref, memref, memref, index, index, f32 + dip.corr_2d %inputImage, %kernel, %outputImage, %centerX, %centerY , %constantValue : memref, memref, memref, index, index, f32 return } diff --git a/include/Dialect/DIP/DIPDialect.td b/include/Dialect/DIP/DIPDialect.td index a4e129baa..f2026ff5b 100644 --- a/include/Dialect/DIP/DIPDialect.td +++ b/include/Dialect/DIP/DIPDialect.td @@ -35,6 +35,7 @@ def DIP_Dialect : Dialect { of developing a MLIR backend for performing image processing operations such as 2D Correlation, Morphological processing, etc. }]; + let useDefaultAttributePrinterParser = 1; let cppNamespace = "::buddy::dip"; } diff --git a/include/Dialect/DIP/DIPOps.td b/include/Dialect/DIP/DIPOps.td index e719627bb..b993c5d0f 100644 --- a/include/Dialect/DIP/DIPOps.td +++ b/include/Dialect/DIP/DIPOps.td @@ -55,30 +55,31 @@ def DIP_InterpolationType : I32EnumAttr<"InterpolationType", let cppNamespace = "::buddy::dip"; } -def DIP_BoundaryOptionAttr : EnumAttr; +def DIP_BoundaryOptionAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} def DIP_InterpolationAttr : EnumAttr; -def DIP_Corr2DOp : DIP_Op<"corr_2d"> -{ +def DIP_Corr2DOp : DIP_Op<"corr_2d"> { let summary = [{This operation is used for performing 2D correlation on an image. - The 2D correlation API provided by the linalg dialect is more suited for - applications in which boundary extrapolation is not explicitly required. - Due to this, dimensions of output are always less than the input dimensions after - using linalg dialect's 2D correlation API. - - dip.corr_2d performs boundary extrapolation for making the size of the output image - equal to the size of the input image. Boundary extrapolation can be done using - different methods, supported options are : - a. Constant Padding : Uses a constant for padding whole extra region in input image - for obtaining the boundary extrapolated output image. (kkk|abcdefg|kkk) - b. Replicate Padding : Uses last/first element of respective column/row for padding - the extra region used for creating the boundary extrapolated output image. (aaa|abcdefg|ggg) - For example: + The 2D correlation API provided by the linalg dialect is more suited for + applications in which boundary extrapolation is not explicitly required. + Due to this, dimensions of output are always less than the input dimensions after + using linalg dialect's 2D correlation API. + + dip.corr_2d performs boundary extrapolation for making the size of the output image + equal to the size of the input image. Boundary extrapolation can be done using + different methods, supported options are: + a. Constant Padding : Uses a constant for padding whole extra region in input image + for obtaining the boundary extrapolated output image. (kkk|abcdefg|kkk) + b. Replicate Padding : Uses last/first element of respective column/row for padding + the extra region used for creating the boundary extrapolated output image. (aaa|abcdefg|ggg) + For example: - ```mlir - dip.corr_2d CONSTANT_PADDING %inputImage, %kernel, %output, %centerX, %centerY, %constantValue - : memref, memref, memref, index, index, index - ``` + ```mlir + dip.corr_2d CONSTANT_PADDING %inputImage, %kernel, %output, %centerX, %centerY, %constantValue + : memref, memref, memref, index, index, index + ``` }]; let arguments = (ins Arg [MemRead]>:$memrefK, Arg:$memrefCO, - Index : $centerX, Index : $centerY, F32 : $constantValue, + Index : $centerX, + Index : $centerY, + AnyTypeOf<[AnyI8, AnyI32, AnyI64, AnyFloat]> : $constantValue, DIP_BoundaryOptionAttr:$boundary_option); let assemblyFormat = [{ diff --git a/include/Utils/DIPUtils.h b/include/Utils/DIPUtils.h index 74b6712fa..2c5369fe6 100644 --- a/include/Utils/DIPUtils.h +++ b/include/Utils/DIPUtils.h @@ -24,6 +24,44 @@ #include "Utils/Utils.h" +// Inserts a constant op with value 0 into a location `loc` based on type +// `type`. Supported types are : f32, f64, integer types +Value insertZeroConstantOp(MLIRContext *ctx, OpBuilder &builder, Location loc, + Type elemTy) { + Value op = {}; + auto bitWidth = elemTy.getIntOrFloatBitWidth(); + if (elemTy.isF32() || elemTy.isF64()) { + FloatType type = + elemTy.isF32() ? FloatType::getF32(ctx) : FloatType::getF64(ctx); + auto zero = APFloat::getZero(type.getFloatSemantics()); + op = builder.create(loc, zero, type); + } else if (elemTy.isInteger(bitWidth)) { + IntegerType type = IntegerType::get(ctx, bitWidth); + op = builder.create(loc, 0, type); + } + + return op; +} + +// Inserts FMA operation into a given location `loc` based on type `type`. +// Note: FMA is done by Multiply and Add for integer types, because there is no +// dedicated FMA operation for them. +// Supported types: f32, f64, integer types +Value insertFMAOp(OpBuilder &builder, Location loc, VectorType type, + Value inputVec, Value kernelVec, Value outputVec) { + Value res = {}; + auto elemTy = type.getElementType(); + auto bitWidth = elemTy.getIntOrFloatBitWidth(); + if (elemTy.isF32() || elemTy.isF64()) { + res = builder.create(loc, inputVec, kernelVec, outputVec); + } else if (elemTy.isInteger(bitWidth)) { + Value mul = builder.create(loc, inputVec, kernelVec); + res = builder.create(loc, mul, outputVec); + } + + return res; +} + // Calculate result of FMA and store it in output memref. This function cannot // handle tail processing. void calcAndStoreFMAwoTailProcessing(OpBuilder &builder, Location loc, @@ -32,7 +70,8 @@ void calcAndStoreFMAwoTailProcessing(OpBuilder &builder, Location loc, Value beginIdx, Value endIdx) { Value outputVec = builder.create(loc, vecType, output, ValueRange{beginIdx, endIdx}); - Value resVec = builder.create(loc, inputVec, kernelVec, outputVec); + Value resVec = + insertFMAOp(builder, loc, vecType, inputVec, kernelVec, outputVec); builder.create(loc, resVec, output, ValueRange{beginIdx, endIdx}); } @@ -72,7 +111,7 @@ void calcAndStoreFMAwTailProcessing(OpBuilder &builder, Location loc, Value outputVec = builder.create(loc, vecType, output, ValueRange{beginIdx, endIdx}); Value resVec = - builder.create(loc, inputVec, kernelVec, outputVec); + insertFMAOp(builder, loc, vecType, inputVec, kernelVec, outputVec); builder.create(loc, resVec, output, ValueRange{beginIdx, endIdx}); @@ -85,7 +124,7 @@ void calcAndStoreFMAwTailProcessing(OpBuilder &builder, Location loc, loc, vecType, output, ValueRange{beginIdx, endIdx}, extraElemMask, zeroPadding); Value resVec = - builder.create(loc, inputVec, kernelVec, outputVec); + insertFMAOp(builder, loc, vecType, inputVec, kernelVec, outputVec); builder.create(loc, output, ValueRange{beginIdx, endIdx}, extraElemMask, resVec); diff --git a/lib/Conversion/LowerDIP/LowerDIPPass.cpp b/lib/Conversion/LowerDIP/LowerDIPPass.cpp index c64cdbc05..c2a25343e 100644 --- a/lib/Conversion/LowerDIP/LowerDIPPass.cpp +++ b/lib/Conversion/LowerDIP/LowerDIPPass.cpp @@ -25,6 +25,10 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Pass/Pass.h" #include "DIP/DIPDialect.h" @@ -56,7 +60,7 @@ class DIPCorr2DOpLowering : public OpRewritePattern { LogicalResult matchAndRewrite(dip::Corr2DOp op, PatternRewriter &rewriter) const override { auto loc = op->getLoc(); - auto ctx = op->getContext(); + auto *ctx = op->getContext(); // Create constant indices. Value c0 = rewriter.create(loc, 0); @@ -72,7 +76,24 @@ class DIPCorr2DOpLowering : public OpRewritePattern { auto boundaryOptionAttr = op.boundary_option(); Value strideVal = rewriter.create(loc, stride); - FloatType f32 = FloatType::getF32(ctx); + auto inElemTy = input.getType().cast().getElementType(); + auto kElemTy = kernel.getType().cast().getElementType(); + auto outElemTy = output.getType().cast().getElementType(); + auto constElemTy = constantValue.getType(); + if (inElemTy != kElemTy || kElemTy != outElemTy || + outElemTy != constElemTy) { + return op->emitOpError() << "input, kernel, output and constant must " + "have the same element type"; + } + // NB: we can infer element type for all operation to be the same as input + // since we verified that the operand types are the same + auto elemTy = inElemTy; + auto bitWidth = elemTy.getIntOrFloatBitWidth(); + if (!elemTy.isF64() && !elemTy.isF32() && !elemTy.isInteger(bitWidth)) { + return op->emitOpError() << "supports only f32, f64 and integer types. " + << elemTy << "is passed"; + } + IntegerType i1 = IntegerType::get(ctx, 1); // Create DimOp. @@ -90,11 +111,10 @@ class DIPCorr2DOpLowering : public OpRewritePattern { kernelSize}; SmallVector steps{1, 1, stride, 1}; - VectorType vectorTy32 = VectorType::get({stride}, f32); + VectorType vectorTy32 = VectorType::get({stride}, elemTy); VectorType vectorMaskTy = VectorType::get({stride}, i1); - Value zeroPaddingElem = - rewriter.create(loc, (APFloat)(float)0, f32); + Value zeroPaddingElem = insertZeroConstantOp(ctx, rewriter, loc, elemTy); Value zeroPadding = rewriter.create(loc, vectorTy32, zeroPaddingElem); diff --git a/tests/Dialect/DIP/correlation2D_f32.mlir b/tests/Dialect/DIP/correlation2D_f32.mlir new file mode 100644 index 000000000..f9b53474c --- /dev/null +++ b/tests/Dialect/DIP/correlation2D_f32.mlir @@ -0,0 +1,42 @@ +// +// x86 +// +// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \ +// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=i32 \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +memref.global "private" @global_input : memref<3x3xf32> = dense<[[0. , 1. , 2. ], + [10., 11., 12.], + [20., 21., 22.]]> + +memref.global "private" @global_identity : memref<3x3xf32> = dense<[[0., 0., 0.], + [0., 1., 0.], + [0., 0., 0.]]> + +memref.global "private" @global_output : memref<3x3xf32> = dense<[[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]> +func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface } + +func.func @main() -> i32 { + %input = memref.get_global @global_input : memref<3x3xf32> + %identity = memref.get_global @global_identity : memref<3x3xf32> + %output = memref.get_global @global_output : memref<3x3xf32> + + %kernelAnchorX = arith.constant 1 : index + %kernelAnchorY = arith.constant 1 : index + %c = arith.constant 0. : f32 + dip.corr_2d %input, %identity, %output, %kernelAnchorX, %kernelAnchorY, %c : memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>, index, index, f32 + + %printed_output = memref.cast %output : memref<3x3xf32> to memref<*xf32> + call @printMemrefF32(%printed_output) : (memref<*xf32>) -> () + // CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 2 offset = 0 sizes = \[3, 3\] strides = \[3, 1\] data =}} + // CHECK{LITERAL}: [[0, 1, 2], + // CHECK{LITERAL}: [10, 11, 12], + // CHECK{LITERAL}: [20, 21, 22]] + + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/tests/Dialect/DIP/correlation2D_f64.mlir b/tests/Dialect/DIP/correlation2D_f64.mlir new file mode 100644 index 000000000..c9cdbe31e --- /dev/null +++ b/tests/Dialect/DIP/correlation2D_f64.mlir @@ -0,0 +1,42 @@ +// +// x86 +// +// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \ +// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=i32 \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +memref.global "private" @global_input : memref<3x3xf64> = dense<[[0. , 1. , 2. ], + [10., 11., 12.], + [20., 21., 22.]]> + +memref.global "private" @global_identity : memref<3x3xf64> = dense<[[0., 0., 0.], + [0., 1., 0.], + [0., 0., 0.]]> + +memref.global "private" @global_output : memref<3x3xf64> = dense<[[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]> +func.func private @printMemrefF64(memref<*xf64>) attributes { llvm.emit_c_interface } + +func.func @main() -> i32 { + %input = memref.get_global @global_input : memref<3x3xf64> + %identity = memref.get_global @global_identity : memref<3x3xf64> + %output = memref.get_global @global_output : memref<3x3xf64> + + %kernelAnchorX = arith.constant 1 : index + %kernelAnchorY = arith.constant 1 : index + %c = arith.constant 0. : f64 + dip.corr_2d %input, %identity, %output, %kernelAnchorX, %kernelAnchorY, %c : memref<3x3xf64>, memref<3x3xf64>, memref<3x3xf64>, index, index, f64 + + %printed_output = memref.cast %output : memref<3x3xf64> to memref<*xf64> + call @printMemrefF64(%printed_output) : (memref<*xf64>) -> () + // CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 2 offset = 0 sizes = \[3, 3\] strides = \[3, 1\] data =}} + // CHECK{LITERAL}: [[0, 1, 2], + // CHECK{LITERAL}: [10, 11, 12], + // CHECK{LITERAL}: [20, 21, 22]] + + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/tests/Dialect/DIP/correlation2D_i32.mlir b/tests/Dialect/DIP/correlation2D_i32.mlir new file mode 100644 index 000000000..831ce3b0e --- /dev/null +++ b/tests/Dialect/DIP/correlation2D_i32.mlir @@ -0,0 +1,41 @@ +// +// x86 +// +// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \ +// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=i32 \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +memref.global "private" @global_input : memref<3x3xi32> = dense<[[0 , 1 , 2 ], + [10, 11, 12], + [20, 21, 22]]> + +memref.global "private" @global_identity : memref<3x3xi32> = dense<[[0, 0, 0], + [0, 1, 0], + [0, 0, 0]]> + +memref.global "private" @global_output : memref<3x3xi32> = dense<[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]> +func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface } + +func.func @main() -> i32 { + %input = memref.get_global @global_input : memref<3x3xi32> + %identity = memref.get_global @global_identity : memref<3x3xi32> + %output = memref.get_global @global_output: memref<3x3xi32> + + %kernelAnchorX = arith.constant 1 : index + %kernelAnchorY = arith.constant 1 : index + %c = arith.constant 0 : i32 + dip.corr_2d %input, %identity, %output, %kernelAnchorX, %kernelAnchorY, %c : memref<3x3xi32>, memref<3x3xi32>, memref<3x3xi32>, index, index, i32 + + %printed_output = memref.cast %output : memref<3x3xi32> to memref<*xi32> + call @printMemrefI32(%printed_output) : (memref<*xi32>) -> () + // CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 2 offset = 0 sizes = \[3, 3\] strides = \[3, 1\] data =}} + // CHECK{LITERAL}: [[0, 1, 2], + // CHECK{LITERAL}: [10, 11, 12], + // CHECK{LITERAL}: [20, 21, 22]] + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/tests/Dialect/DIP/correlation2D_i64.mlir b/tests/Dialect/DIP/correlation2D_i64.mlir new file mode 100644 index 000000000..46acf0b5f --- /dev/null +++ b/tests/Dialect/DIP/correlation2D_i64.mlir @@ -0,0 +1,42 @@ +// +// x86 +// +// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \ +// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=i32 \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +memref.global "private" @global_input : memref<3x3xi64> = dense<[[0 , 1 , 2 ], + [10, 11, 12], + [20, 21, 22]]> + +memref.global "private" @global_identity : memref<3x3xi64> = dense<[[0, 0, 0], + [0, 1, 0], + [0, 0, 0]]> + +memref.global "private" @global_output : memref<3x3xi64> = dense<[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]> + +func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface } + +func.func @main() -> i32 { + %input = memref.get_global @global_input : memref<3x3xi64> + %identity = memref.get_global @global_identity : memref<3x3xi64> + %output = memref.get_global @global_output: memref<3x3xi64> + + %kernelAnchorX = arith.constant 1 : index + %kernelAnchorY = arith.constant 1 : index + %c = arith.constant 0 : i64 + dip.corr_2d %input, %identity, %output, %kernelAnchorX, %kernelAnchorY, %c : memref<3x3xi64>, memref<3x3xi64>, memref<3x3xi64>, index, index, i64 + + %printed_output = memref.cast %output : memref<3x3xi64> to memref<*xi64> + call @printMemrefI64(%printed_output) : (memref<*xi64>) -> () + // CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 2 offset = 0 sizes = \[3, 3\] strides = \[3, 1\] data =}} + // CHECK{LITERAL}: [[0, 1, 2], + // CHECK{LITERAL}: [10, 11, 12], + // CHECK{LITERAL}: [20, 21, 22]] + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/tests/Dialect/DIP/correlation2D_i8.mlir b/tests/Dialect/DIP/correlation2D_i8.mlir new file mode 100644 index 000000000..10c466c75 --- /dev/null +++ b/tests/Dialect/DIP/correlation2D_i8.mlir @@ -0,0 +1,43 @@ +// +// x86 +// +// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \ +// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=i32 \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +memref.global "private" @global_input : memref<3x3xi8> = dense<[[97, 97, 97], + [97, 97, 97], + [97, 97, 97]]> + +memref.global "private" @global_identity : memref<3x3xi8> = dense<[[0, 0, 0], + [0, 1, 0], + [0, 0, 0]]> + +memref.global "private" @global_output : memref<3x3xi8> = dense<[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]> + +func.func private @printMemrefI8(memref<*xi8>) attributes { llvm.emit_c_interface } + +func.func @main() -> i32 { + %input = memref.get_global @global_input : memref<3x3xi8> + %identity = memref.get_global @global_identity : memref<3x3xi8> + %output = memref.get_global @global_output: memref<3x3xi8> + + %kernelAnchorX = arith.constant 1 : index + %kernelAnchorY = arith.constant 1 : index + %c = arith.constant 0 : i8 + dip.corr_2d %input, %identity, %output, %kernelAnchorX, %kernelAnchorY, %c : memref<3x3xi8>, memref<3x3xi8>, memref<3x3xi8>, index, index, i8 + + %printed_output = memref.cast %output : memref<3x3xi8> to memref<*xi8> + call @printMemrefI8(%printed_output) : (memref<*xi8>) -> () + // CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 2 offset = 0 sizes = \[3, 3\] strides = \[3, 1\] data =}} + // a is ASCII for 97 + // CHECK{LITERAL}: [[a, a, a], + // CHECK{LITERAL}: [a, a, a], + // CHECK{LITERAL}: [a, a, a]] + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/tests/Dialect/DIP/correlation2D_invalid_type.mlir b/tests/Dialect/DIP/correlation2D_invalid_type.mlir new file mode 100644 index 000000000..7ffb59fee --- /dev/null +++ b/tests/Dialect/DIP/correlation2D_invalid_type.mlir @@ -0,0 +1,78 @@ +// +// x86 +// +// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \ +// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts 2>&1 | FileCheck %s + +memref.global "private" @global_input_f32 : memref<3x3xf32> = dense<[[0. , 1. , 2. ], + [10., 11., 12.], + [20., 21., 22.]]> + +memref.global "private" @global_identity_f32 : memref<3x3xf32> = dense<[[0., 0., 0.], + [0., 1., 0.], + [0., 0., 0.]]> + +memref.global "private" @global_output_f32 : memref<3x3xf32> = dense<[[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]> + +memref.global "private" @global_input_i32 : memref<3x3xi32> = dense<[[0 , 1 , 2 ], + [10, 11, 12], + [20, 21, 22]]> + +memref.global "private" @global_identity_i32 : memref<3x3xi32> = dense<[[0, 0, 0], + [0, 1, 0], + [0, 0, 0]]> + +memref.global "private" @global_output_i32 : memref<3x3xi32> = dense<[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]> + +memref.global "private" @global_input_f128 : memref<3x3xf128> = dense<[[0. , 1. , 2. ], + [10., 11., 12.], + [20., 21., 22.]]> + +memref.global "private" @global_output_f128 : memref<3x3xf128> = dense<[[0., 0., 0. ], + [0., 1., 0.], + [0., 0., 0.]]> + +memref.global "private" @global_identity_f128 : memref<3x3xf128> = dense<[[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]> + +func.func @main() -> i32 { + %input_f32 = memref.get_global @global_input_f32 : memref<3x3xf32> + %identity_f32 = memref.get_global @global_identity_f32 : memref<3x3xf32> + %output_f32 = memref.get_global @global_output_f32 : memref<3x3xf32> + %c_f32 = arith.constant 0. : f32 + + %input_i32 = memref.get_global @global_input_i32 : memref<3x3xi32> + %identity_i32 = memref.get_global @global_identity_i32 : memref<3x3xi32> + %output_i32 = memref.get_global @global_output_i32 : memref<3x3xi32> + %c_i32 = arith.constant 0 : i32 + + %input_f128 = memref.get_global @global_input_f128 : memref<3x3xf128> + %identity_f128 = memref.get_global @global_identity_f128 : memref<3x3xf128> + %output_f128 = memref.get_global @global_output_f128 : memref<3x3xf128> + %c_f128 = arith.constant 0. : f128 + + %kernelAnchorX = arith.constant 1 : index + %kernelAnchorY = arith.constant 1 : index + + dip.corr_2d %input_i32, %identity_f32, %output_f32, %kernelAnchorX, %kernelAnchorY, %c_f32 : memref<3x3xi32>, memref<3x3xf32>, memref<3x3xf32>, index, index, f32 + // CHECK: 'dip.corr_2d' op input, kernel, output and constant must have the same element type + + dip.corr_2d %input_f32, %identity_i32, %output_f32, %kernelAnchorX, %kernelAnchorY, %c_f32 : memref<3x3xf32>, memref<3x3xi32>, memref<3x3xf32>, index, index, f32 + // CHECK: 'dip.corr_2d' op input, kernel, output and constant must have the same element type + + dip.corr_2d %input_f32, %identity_f32, %output_i32, %kernelAnchorX, %kernelAnchorY, %c_f32 : memref<3x3xf32>, memref<3x3xf32>, memref<3x3xi32>, index, index, f32 + // CHECK: 'dip.corr_2d' op input, kernel, output and constant must have the same element type + + dip.corr_2d %input_f32, %identity_f32, %output_f32, %kernelAnchorX, %kernelAnchorY, %c_i32 : memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>, index, index, i32 + // CHECK: 'dip.corr_2d' op input, kernel, output and constant must have the same element type + + dip.corr_2d %input_f128, %identity_f128, %output_f128, %kernelAnchorX, %kernelAnchorY, %c_f128 : memref<3x3xf128>, memref<3x3xf128>, memref<3x3xf128>, index, index, f128 + // CHECK: 'dip.corr_2d' op supports only f32, f64 and integer types. 'f128'is passed + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/tests/lit.cfg.py b/tests/lit.cfg.py index e1add15f1..7be79430b 100644 --- a/tests/lit.cfg.py +++ b/tests/lit.cfg.py @@ -57,8 +57,12 @@ 'buddy-opt', 'buddy-translate', 'buddy-container-test', - 'buddy-audio-container-test' + 'buddy-audio-container-test', + 'mlir-cpu-runner', ] +tools.extend([ + ToolSubst('%mlir_runner_utils_dir', config.mlir_runner_utils_dir, unresolved='ignore'), +]) if config.buddy_enable_opencv == "ON": tools.append('buddy-image-container-test') diff --git a/tests/lit.site.cfg.py.in b/tests/lit.site.cfg.py.in index 68fc0b481..592ec9a18 100644 --- a/tests/lit.site.cfg.py.in +++ b/tests/lit.site.cfg.py.in @@ -32,6 +32,7 @@ config.host_arch = "@HOST_ARCH@" config.buddy_src_root = "@CMAKE_SOURCE_DIR@" config.buddy_obj_root = "@CMAKE_BINARY_DIR@" config.buddy_enable_opencv = "@BUDDY_ENABLE_OPENCV@" +config.mlir_runner_utils_dir = "@LLVM_LIBS_DIR@" # Support substitution of the tools_dir with user parameters. This is # used when we can't determine the tool dir at configuration time.