Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer operation data type based on its params for Corr2D #63

Merged
merged 16 commits into from
Aug 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/DIPDialect/dip.mlir
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
func.func @corr_2d_constant_padding(%inputImage : memref<?x?xf32>, %kernel : memref<?x?xf32>, %outputImage : memref<?x?xf32>, %centerX : index, %centerY : index, %constantValue : f32)
{
dip.corr_2d CONSTANT_PADDING %inputImage, %kernel, %outputImage, %centerX, %centerY, %constantValue : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
dip.corr_2d <CONSTANT_PADDING> %inputImage, %kernel, %outputImage, %centerX, %centerY, %constantValue : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
return
}

func.func @corr_2d_replicate_padding(%inputImage : memref<?x?xf32>, %kernel : memref<?x?xf32>, %outputImage : memref<?x?xf32>, %centerX : index, %centerY : index, %constantValue : f32)
{
dip.corr_2d REPLICATE_PADDING %inputImage, %kernel, %outputImage, %centerX, %centerY , %constantValue : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
dip.corr_2d <REPLICATE_PADDING> %inputImage, %kernel, %outputImage, %centerX, %centerY , %constantValue : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
return
}

Expand Down
1 change: 1 addition & 0 deletions include/Dialect/DIP/DIPDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def DIP_Dialect : Dialect {
of developing a MLIR backend for performing image processing operations
such as 2D Correlation, Morphological processing, etc.
}];
let useDefaultAttributePrinterParser = 1;
let cppNamespace = "::buddy::dip";
}

Expand Down
45 changes: 24 additions & 21 deletions include/Dialect/DIP/DIPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,30 +55,31 @@ def DIP_InterpolationType : I32EnumAttr<"InterpolationType",
let cppNamespace = "::buddy::dip";
}

def DIP_BoundaryOptionAttr : EnumAttr<DIP_Dialect, DIP_BoundaryOption, "boundary_option">;
def DIP_BoundaryOptionAttr : EnumAttr<DIP_Dialect, DIP_BoundaryOption, "boundary_option"> {
let assemblyFormat = "`<` $value `>`";
}
def DIP_InterpolationAttr : EnumAttr<DIP_Dialect, DIP_InterpolationType, "interpolation_type">;

def DIP_Corr2DOp : DIP_Op<"corr_2d">
{
def DIP_Corr2DOp : DIP_Op<"corr_2d"> {
let summary = [{This operation is used for performing 2D correlation on an image.
The 2D correlation API provided by the linalg dialect is more suited for
applications in which boundary extrapolation is not explicitly required.
Due to this, dimensions of output are always less than the input dimensions after
using linalg dialect's 2D correlation API.

dip.corr_2d performs boundary extrapolation for making the size of the output image
equal to the size of the input image. Boundary extrapolation can be done using
different methods, supported options are :
a. Constant Padding : Uses a constant for padding whole extra region in input image
for obtaining the boundary extrapolated output image. (kkk|abcdefg|kkk)
b. Replicate Padding : Uses last/first element of respective column/row for padding
the extra region used for creating the boundary extrapolated output image. (aaa|abcdefg|ggg)
For example:
The 2D correlation API provided by the linalg dialect is more suited for
applications in which boundary extrapolation is not explicitly required.
Due to this, dimensions of output are always less than the input dimensions after
using linalg dialect's 2D correlation API.

dip.corr_2d performs boundary extrapolation for making the size of the output image
equal to the size of the input image. Boundary extrapolation can be done using
different methods, supported options are:
a. Constant Padding : Uses a constant for padding whole extra region in input image
for obtaining the boundary extrapolated output image. (kkk|abcdefg|kkk)
b. Replicate Padding : Uses last/first element of respective column/row for padding
the extra region used for creating the boundary extrapolated output image. (aaa|abcdefg|ggg)
For example:

```mlir
dip.corr_2d CONSTANT_PADDING %inputImage, %kernel, %output, %centerX, %centerY, %constantValue
: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, index
```
```mlir
dip.corr_2d <CONSTANT_PADDING> %inputImage, %kernel, %output, %centerX, %centerY, %constantValue
: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
```
}];

let arguments = (ins Arg<AnyRankedOrUnrankedMemRef, "inputMemref",
Expand All @@ -87,7 +88,9 @@ def DIP_Corr2DOp : DIP_Op<"corr_2d">
[MemRead]>:$memrefK,
Arg<AnyRankedOrUnrankedMemRef, "outputMemref",
[MemRead]>:$memrefCO,
Index : $centerX, Index : $centerY, F32 : $constantValue,
Index : $centerX,
Index : $centerY,
AnyTypeOf<[AnyI8, AnyI32, AnyI64, AnyFloat]> : $constantValue,
DIP_BoundaryOptionAttr:$boundary_option);

let assemblyFormat = [{
Expand Down
45 changes: 42 additions & 3 deletions include/Utils/DIPUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,44 @@

#include "Utils/Utils.h"

// Inserts a constant op with value 0 at location `loc`, typed according to
// `elemTy`.
//
// Supported element types: f32, f64 and integer types of any bit width. For
// integer types the constant is created with the signless integer type of the
// same bit width as `elemTy`.
//
// Returns the created constant, or a null `Value` when `elemTy` is not one of
// the supported types; callers are expected to check the result.
//
// The type is classified before its bit width is read: querying
// `getIntOrFloatBitWidth()` on a type that is neither integer nor float
// (e.g. `index`) trips an assertion instead of reaching the null-return path.
Value insertZeroConstantOp(MLIRContext *ctx, OpBuilder &builder, Location loc,
                           Type elemTy) {
  Value op = {};
  if (elemTy.isF32() || elemTy.isF64()) {
    FloatType type =
        elemTy.isF32() ? FloatType::getF32(ctx) : FloatType::getF64(ctx);
    auto zero = APFloat::getZero(type.getFloatSemantics());
    op = builder.create<ConstantFloatOp>(loc, zero, type);
  } else if (auto intTy = elemTy.dyn_cast<IntegerType>()) {
    // Only the width is carried over; the constant itself is signless.
    IntegerType type = IntegerType::get(ctx, intTy.getWidth());
    op = builder.create<ConstantIntOp>(loc, 0, type);
  }

  return op;
}

// Inserts a fused multiply-add, `inputVec * kernelVec + outputVec`, at
// location `loc`, choosing the ops from the element type of `type`.
//
// - f32 / f64 elements: a single `vector::FMAOp`.
// - integer elements: `arith::MulIOp` followed by `arith::AddIOp`, because
//   there is no dedicated FMA operation for integer types.
//
// Returns the result value, or a null `Value` when the element type is not
// supported; callers are expected to check the result.
//
// The element type is classified before any bit-width query: calling
// `getIntOrFloatBitWidth()` on an element type that is neither integer nor
// float (e.g. `index`) trips an assertion instead of reaching the null-return
// path.
Value insertFMAOp(OpBuilder &builder, Location loc, VectorType type,
                  Value inputVec, Value kernelVec, Value outputVec) {
  Value res = {};
  auto elemTy = type.getElementType();
  if (elemTy.isF32() || elemTy.isF64()) {
    res = builder.create<vector::FMAOp>(loc, inputVec, kernelVec, outputVec);
  } else if (elemTy.isa<IntegerType>()) {
    Value mul = builder.create<arith::MulIOp>(loc, inputVec, kernelVec);
    res = builder.create<arith::AddIOp>(loc, mul, outputVec);
  }

  return res;
}

// Calculate result of FMA and store it in output memref. This function cannot
// handle tail processing.
void calcAndStoreFMAwoTailProcessing(OpBuilder &builder, Location loc,
Expand All @@ -32,7 +70,8 @@ void calcAndStoreFMAwoTailProcessing(OpBuilder &builder, Location loc,
Value beginIdx, Value endIdx) {
Value outputVec = builder.create<LoadOp>(loc, vecType, output,
ValueRange{beginIdx, endIdx});
Value resVec = builder.create<FMAOp>(loc, inputVec, kernelVec, outputVec);
Value resVec =
insertFMAOp(builder, loc, vecType, inputVec, kernelVec, outputVec);
builder.create<StoreOp>(loc, resVec, output, ValueRange{beginIdx, endIdx});
}

Expand Down Expand Up @@ -72,7 +111,7 @@ void calcAndStoreFMAwTailProcessing(OpBuilder &builder, Location loc,
Value outputVec = builder.create<LoadOp>(loc, vecType, output,
ValueRange{beginIdx, endIdx});
Value resVec =
builder.create<FMAOp>(loc, inputVec, kernelVec, outputVec);
insertFMAOp(builder, loc, vecType, inputVec, kernelVec, outputVec);
builder.create<StoreOp>(loc, resVec, output,
ValueRange{beginIdx, endIdx});

Expand All @@ -85,7 +124,7 @@ void calcAndStoreFMAwTailProcessing(OpBuilder &builder, Location loc,
loc, vecType, output, ValueRange{beginIdx, endIdx}, extraElemMask,
zeroPadding);
Value resVec =
builder.create<FMAOp>(loc, inputVec, kernelVec, outputVec);
insertFMAOp(builder, loc, vecType, inputVec, kernelVec, outputVec);
builder.create<MaskedStoreOp>(loc, output, ValueRange{beginIdx, endIdx},
extraElemMask, resVec);

Expand Down
30 changes: 25 additions & 5 deletions lib/Conversion/LowerDIP/LowerDIPPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"

#include "DIP/DIPDialect.h"
Expand Down Expand Up @@ -56,7 +60,7 @@ class DIPCorr2DOpLowering : public OpRewritePattern<dip::Corr2DOp> {
LogicalResult matchAndRewrite(dip::Corr2DOp op,
PatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto ctx = op->getContext();
auto *ctx = op->getContext();

// Create constant indices.
Value c0 = rewriter.create<ConstantIndexOp>(loc, 0);
Expand All @@ -72,7 +76,24 @@ class DIPCorr2DOpLowering : public OpRewritePattern<dip::Corr2DOp> {
auto boundaryOptionAttr = op.boundary_option();
Value strideVal = rewriter.create<ConstantIndexOp>(loc, stride);

FloatType f32 = FloatType::getF32(ctx);
auto inElemTy = input.getType().cast<MemRefType>().getElementType();
auto kElemTy = kernel.getType().cast<MemRefType>().getElementType();
auto outElemTy = output.getType().cast<MemRefType>().getElementType();
auto constElemTy = constantValue.getType();
if (inElemTy != kElemTy || kElemTy != outElemTy ||
outElemTy != constElemTy) {
return op->emitOpError() << "input, kernel, output and constant must "
"have the same element type";
}
// NB: we can infer element type for all operation to be the same as input
// since we verified that the operand types are the same
auto elemTy = inElemTy;
auto bitWidth = elemTy.getIntOrFloatBitWidth();
if (!elemTy.isF64() && !elemTy.isF32() && !elemTy.isInteger(bitWidth)) {
return op->emitOpError() << "supports only f32, f64 and integer types. "
<< elemTy << "is passed";
}

IntegerType i1 = IntegerType::get(ctx, 1);

// Create DimOp.
Expand All @@ -90,11 +111,10 @@ class DIPCorr2DOpLowering : public OpRewritePattern<dip::Corr2DOp> {
kernelSize};
SmallVector<int64_t, 8> steps{1, 1, stride, 1};

VectorType vectorTy32 = VectorType::get({stride}, f32);
VectorType vectorTy32 = VectorType::get({stride}, elemTy);
VectorType vectorMaskTy = VectorType::get({stride}, i1);

Value zeroPaddingElem =
rewriter.create<ConstantFloatOp>(loc, (APFloat)(float)0, f32);
Value zeroPaddingElem = insertZeroConstantOp(ctx, rewriter, loc, elemTy);
Value zeroPadding =
rewriter.create<BroadcastOp>(loc, vectorTy32, zeroPaddingElem);

Expand Down
42 changes: 42 additions & 0 deletions tests/Dialect/DIP/correlation2D_f32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//
// x86
//
// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \
// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=i32 \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

memref.global "private" @global_input : memref<3x3xf32> = dense<[[0. , 1. , 2. ],
[10., 11., 12.],
[20., 21., 22.]]>

memref.global "private" @global_identity : memref<3x3xf32> = dense<[[0., 0., 0.],
[0., 1., 0.],
[0., 0., 0.]]>

memref.global "private" @global_output : memref<3x3xf32> = dense<[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]>
func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }

// Test entry point for the f32 variant of dip.corr_2d.
// Correlates @global_input with @global_identity (kernel anchor at (1, 1),
// constant padding value 0.0), writes into @global_output and prints it.
// The FileCheck expectations below require the printed values to match the
// input matrix. Returns 0 as the process exit value.
func.func @main() -> i32 {
%input = memref.get_global @global_input : memref<3x3xf32>
%identity = memref.get_global @global_identity : memref<3x3xf32>
%output = memref.get_global @global_output : memref<3x3xf32>

%kernelAnchorX = arith.constant 1 : index
%kernelAnchorY = arith.constant 1 : index
%c = arith.constant 0. : f32
dip.corr_2d <CONSTANT_PADDING> %input, %identity, %output, %kernelAnchorX, %kernelAnchorY, %c : memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>, index, index, f32

// The print helper takes an unranked memref, hence the cast.
%printed_output = memref.cast %output : memref<3x3xf32> to memref<*xf32>
call @printMemrefF32(%printed_output) : (memref<*xf32>) -> ()
// CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 2 offset = 0 sizes = \[3, 3\] strides = \[3, 1\] data =}}
// CHECK{LITERAL}: [[0, 1, 2],
// CHECK{LITERAL}: [10, 11, 12],
// CHECK{LITERAL}: [20, 21, 22]]

%ret = arith.constant 0 : i32
return %ret : i32
}
42 changes: 42 additions & 0 deletions tests/Dialect/DIP/correlation2D_f64.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//
// x86
//
// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \
// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=i32 \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

memref.global "private" @global_input : memref<3x3xf64> = dense<[[0. , 1. , 2. ],
[10., 11., 12.],
[20., 21., 22.]]>

memref.global "private" @global_identity : memref<3x3xf64> = dense<[[0., 0., 0.],
[0., 1., 0.],
[0., 0., 0.]]>

memref.global "private" @global_output : memref<3x3xf64> = dense<[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]>
func.func private @printMemrefF64(memref<*xf64>) attributes { llvm.emit_c_interface }

// Test entry point for the f64 variant of dip.corr_2d.
// Correlates @global_input with @global_identity (kernel anchor at (1, 1),
// constant padding value 0.0), writes into @global_output and prints it.
// The FileCheck expectations below require the printed values to match the
// input matrix. Returns 0 as the process exit value.
func.func @main() -> i32 {
%input = memref.get_global @global_input : memref<3x3xf64>
%identity = memref.get_global @global_identity : memref<3x3xf64>
%output = memref.get_global @global_output : memref<3x3xf64>

%kernelAnchorX = arith.constant 1 : index
%kernelAnchorY = arith.constant 1 : index
%c = arith.constant 0. : f64
dip.corr_2d <CONSTANT_PADDING> %input, %identity, %output, %kernelAnchorX, %kernelAnchorY, %c : memref<3x3xf64>, memref<3x3xf64>, memref<3x3xf64>, index, index, f64

// The print helper takes an unranked memref, hence the cast.
%printed_output = memref.cast %output : memref<3x3xf64> to memref<*xf64>
call @printMemrefF64(%printed_output) : (memref<*xf64>) -> ()
// CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 2 offset = 0 sizes = \[3, 3\] strides = \[3, 1\] data =}}
// CHECK{LITERAL}: [[0, 1, 2],
// CHECK{LITERAL}: [10, 11, 12],
// CHECK{LITERAL}: [20, 21, 22]]

%ret = arith.constant 0 : i32
return %ret : i32
}
41 changes: 41 additions & 0 deletions tests/Dialect/DIP/correlation2D_i32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//
// x86
//
// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \
// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=i32 \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

memref.global "private" @global_input : memref<3x3xi32> = dense<[[0 , 1 , 2 ],
[10, 11, 12],
[20, 21, 22]]>

memref.global "private" @global_identity : memref<3x3xi32> = dense<[[0, 0, 0],
[0, 1, 0],
[0, 0, 0]]>

memref.global "private" @global_output : memref<3x3xi32> = dense<[[0, 0, 0],
[0, 0, 0],
[0, 0, 0]]>
func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface }

// Test entry point for the i32 variant of dip.corr_2d.
// Correlates @global_input with @global_identity (kernel anchor at (1, 1),
// constant padding value 0), writes into @global_output and prints it.
// The FileCheck expectations below require the printed values to match the
// input matrix. Returns 0 as the process exit value.
func.func @main() -> i32 {
%input = memref.get_global @global_input : memref<3x3xi32>
%identity = memref.get_global @global_identity : memref<3x3xi32>
%output = memref.get_global @global_output: memref<3x3xi32>

%kernelAnchorX = arith.constant 1 : index
%kernelAnchorY = arith.constant 1 : index
%c = arith.constant 0 : i32
dip.corr_2d <CONSTANT_PADDING> %input, %identity, %output, %kernelAnchorX, %kernelAnchorY, %c : memref<3x3xi32>, memref<3x3xi32>, memref<3x3xi32>, index, index, i32

// The print helper takes an unranked memref, hence the cast.
%printed_output = memref.cast %output : memref<3x3xi32> to memref<*xi32>
call @printMemrefI32(%printed_output) : (memref<*xi32>) -> ()
// CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 2 offset = 0 sizes = \[3, 3\] strides = \[3, 1\] data =}}
// CHECK{LITERAL}: [[0, 1, 2],
// CHECK{LITERAL}: [10, 11, 12],
// CHECK{LITERAL}: [20, 21, 22]]
%ret = arith.constant 0 : i32
return %ret : i32
}
42 changes: 42 additions & 0 deletions tests/Dialect/DIP/correlation2D_i64.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//
// x86
//
// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \
// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=i32 \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

memref.global "private" @global_input : memref<3x3xi64> = dense<[[0 , 1 , 2 ],
[10, 11, 12],
[20, 21, 22]]>

memref.global "private" @global_identity : memref<3x3xi64> = dense<[[0, 0, 0],
[0, 1, 0],
[0, 0, 0]]>

memref.global "private" @global_output : memref<3x3xi64> = dense<[[0, 0, 0],
[0, 0, 0],
[0, 0, 0]]>

func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface }

// Test entry point for the i64 variant of dip.corr_2d.
// Correlates @global_input with @global_identity (kernel anchor at (1, 1),
// constant padding value 0), writes into @global_output and prints it.
// The FileCheck expectations below require the printed values to match the
// input matrix. Returns 0 as the process exit value.
func.func @main() -> i32 {
%input = memref.get_global @global_input : memref<3x3xi64>
%identity = memref.get_global @global_identity : memref<3x3xi64>
%output = memref.get_global @global_output: memref<3x3xi64>

%kernelAnchorX = arith.constant 1 : index
%kernelAnchorY = arith.constant 1 : index
%c = arith.constant 0 : i64
dip.corr_2d <CONSTANT_PADDING> %input, %identity, %output, %kernelAnchorX, %kernelAnchorY, %c : memref<3x3xi64>, memref<3x3xi64>, memref<3x3xi64>, index, index, i64

// The print helper takes an unranked memref, hence the cast.
%printed_output = memref.cast %output : memref<3x3xi64> to memref<*xi64>
call @printMemrefI64(%printed_output) : (memref<*xi64>) -> ()
// CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 2 offset = 0 sizes = \[3, 3\] strides = \[3, 1\] data =}}
// CHECK{LITERAL}: [[0, 1, 2],
// CHECK{LITERAL}: [10, 11, 12],
// CHECK{LITERAL}: [20, 21, 22]]
%ret = arith.constant 0 : i32
return %ret : i32
}
Loading