Skip to content

Commit

Permalink
[DIP] Infer operation data type based on its params for Corr2D (buddy…
Browse files Browse the repository at this point in the history
…-compiler#63)

* Add lit-test for correlation2D

* Add correlation i32 test

* Add support for i32 correlation 2d

* Enable default attribute printing for DIP

* Add a check for compatbility return type

* Update lit test to find mlir cpu utils

* Make output to be param for Corr2d

  - having it as a return value causes to segfault correlation2D sample
    - it is not clear how to return MemRef from C-interface

* Mark attributes with <> in dip.mlir

* Add a operand type check for Corr2D op and a test

  - Make sure that input, kernel, output and constant have the same
    value and use as inferred type
  - Adding a negative lit test to check params of the op

* Fix review comments

* Add support for i8,i64,f64

* Fix correlation2D_f64 test

 * by constructing F64 correctly

* Fix review comment and formatting issues

* Fix more review comment

  - extend correlation2d_invalid_type test to cover a condition for supported types
  - add comments for utility functions

* Remove insertZeroConstantOp from LowerDIPPass

  - it is present in DIPItility.h

* Trivial changes

Co-authored-by: meshtag <prathameshtagore@gmail.com>
  • Loading branch information
ArtemSkrebkov and meshtag authored Aug 8, 2022
1 parent a42d012 commit 39b1b78
Show file tree
Hide file tree
Showing 13 changed files with 388 additions and 32 deletions.
4 changes: 2 additions & 2 deletions examples/DIPDialect/dip.mlir
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
func.func @corr_2d_constant_padding(%inputImage : memref<?x?xf32>, %kernel : memref<?x?xf32>, %outputImage : memref<?x?xf32>, %centerX : index, %centerY : index, %constantValue : f32)
{
dip.corr_2d CONSTANT_PADDING %inputImage, %kernel, %outputImage, %centerX, %centerY, %constantValue : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
dip.corr_2d <CONSTANT_PADDING> %inputImage, %kernel, %outputImage, %centerX, %centerY, %constantValue : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
return
}

func.func @corr_2d_replicate_padding(%inputImage : memref<?x?xf32>, %kernel : memref<?x?xf32>, %outputImage : memref<?x?xf32>, %centerX : index, %centerY : index, %constantValue : f32)
{
dip.corr_2d REPLICATE_PADDING %inputImage, %kernel, %outputImage, %centerX, %centerY , %constantValue : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
dip.corr_2d <REPLICATE_PADDING> %inputImage, %kernel, %outputImage, %centerX, %centerY , %constantValue : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
return
}

Expand Down
1 change: 1 addition & 0 deletions include/Dialect/DIP/DIPDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def DIP_Dialect : Dialect {
of developing a MLIR backend for performing image processing operations
such as 2D Correlation, Morphological processing, etc.
}];
let useDefaultAttributePrinterParser = 1;
let cppNamespace = "::buddy::dip";
}

Expand Down
45 changes: 24 additions & 21 deletions include/Dialect/DIP/DIPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,30 +55,31 @@ def DIP_InterpolationType : I32EnumAttr<"InterpolationType",
let cppNamespace = "::buddy::dip";
}

def DIP_BoundaryOptionAttr : EnumAttr<DIP_Dialect, DIP_BoundaryOption, "boundary_option">;
def DIP_BoundaryOptionAttr : EnumAttr<DIP_Dialect, DIP_BoundaryOption, "boundary_option"> {
let assemblyFormat = "`<` $value `>`";
}
def DIP_InterpolationAttr : EnumAttr<DIP_Dialect, DIP_InterpolationType, "interpolation_type">;

def DIP_Corr2DOp : DIP_Op<"corr_2d">
{
def DIP_Corr2DOp : DIP_Op<"corr_2d"> {
let summary = [{This operation is used for performing 2D correlation on an image.
The 2D correlation API provided by the linalg dialect is more suited for
applications in which boundary extrapolation is not explicitly required.
Due to this, dimensions of output are always less than the input dimensions after
using linalg dialect's 2D correlation API.

dip.corr_2d performs boundary extrapolation for making the size of the output image
equal to the size of the input image. Boundary extrapolation can be done using
different methods, supported options are :
a. Constant Padding : Uses a constant for padding whole extra region in input image
for obtaining the boundary extrapolated output image. (kkk|abcdefg|kkk)
b. Replicate Padding : Uses last/first element of respective column/row for padding
the extra region used for creating the boundary extrapolated output image. (aaa|abcdefg|ggg)
For example:
The 2D correlation API provided by the linalg dialect is more suited for
applications in which boundary extrapolation is not explicitly required.
Due to this, dimensions of output are always less than the input dimensions after
using linalg dialect's 2D correlation API.

dip.corr_2d performs boundary extrapolation for making the size of the output image
equal to the size of the input image. Boundary extrapolation can be done using
different methods, supported options are:
a. Constant Padding : Uses a constant for padding whole extra region in input image
for obtaining the boundary extrapolated output image. (kkk|abcdefg|kkk)
b. Replicate Padding : Uses last/first element of respective column/row for padding
the extra region used for creating the boundary extrapolated output image. (aaa|abcdefg|ggg)
For example:

```mlir
dip.corr_2d CONSTANT_PADDING %inputImage, %kernel, %output, %centerX, %centerY, %constantValue
: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, index
```
```mlir
dip.corr_2d CONSTANT_PADDING %inputImage, %kernel, %output, %centerX, %centerY, %constantValue
: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, index
```
}];

let arguments = (ins Arg<AnyRankedOrUnrankedMemRef, "inputMemref",
Expand All @@ -87,7 +88,9 @@ def DIP_Corr2DOp : DIP_Op<"corr_2d">
[MemRead]>:$memrefK,
Arg<AnyRankedOrUnrankedMemRef, "outputMemref",
[MemRead]>:$memrefCO,
Index : $centerX, Index : $centerY, F32 : $constantValue,
Index : $centerX,
Index : $centerY,
AnyTypeOf<[AnyI8, AnyI32, AnyI64, AnyFloat]> : $constantValue,
DIP_BoundaryOptionAttr:$boundary_option);

let assemblyFormat = [{
Expand Down
45 changes: 42 additions & 3 deletions include/Utils/DIPUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,44 @@

#include "Utils/Utils.h"

// Inserts a constant op with value 0 into a location `loc` based on type
// `type`. Supported types are : f32, f64, integer types
Value insertZeroConstantOp(MLIRContext *ctx, OpBuilder &builder, Location loc,
Type elemTy) {
Value op = {};
auto bitWidth = elemTy.getIntOrFloatBitWidth();
if (elemTy.isF32() || elemTy.isF64()) {
FloatType type =
elemTy.isF32() ? FloatType::getF32(ctx) : FloatType::getF64(ctx);
auto zero = APFloat::getZero(type.getFloatSemantics());
op = builder.create<ConstantFloatOp>(loc, zero, type);
} else if (elemTy.isInteger(bitWidth)) {
IntegerType type = IntegerType::get(ctx, bitWidth);
op = builder.create<ConstantIntOp>(loc, 0, type);
}

return op;
}

// Inserts FMA operation into a given location `loc` based on type `type`.
// Note: FMA is done by Multiply and Add for integer types, because there is no
// dedicated FMA operation for them.
// Supported types: f32, f64, integer types
Value insertFMAOp(OpBuilder &builder, Location loc, VectorType type,
Value inputVec, Value kernelVec, Value outputVec) {
Value res = {};
auto elemTy = type.getElementType();
auto bitWidth = elemTy.getIntOrFloatBitWidth();
if (elemTy.isF32() || elemTy.isF64()) {
res = builder.create<vector::FMAOp>(loc, inputVec, kernelVec, outputVec);
} else if (elemTy.isInteger(bitWidth)) {
Value mul = builder.create<arith::MulIOp>(loc, inputVec, kernelVec);
res = builder.create<arith::AddIOp>(loc, mul, outputVec);
}

return res;
}

// Calculate result of FMA and store it in output memref. This function cannot
// handle tail processing.
void calcAndStoreFMAwoTailProcessing(OpBuilder &builder, Location loc,
Expand All @@ -32,7 +70,8 @@ void calcAndStoreFMAwoTailProcessing(OpBuilder &builder, Location loc,
Value beginIdx, Value endIdx) {
Value outputVec = builder.create<LoadOp>(loc, vecType, output,
ValueRange{beginIdx, endIdx});
Value resVec = builder.create<FMAOp>(loc, inputVec, kernelVec, outputVec);
Value resVec =
insertFMAOp(builder, loc, vecType, inputVec, kernelVec, outputVec);
builder.create<StoreOp>(loc, resVec, output, ValueRange{beginIdx, endIdx});
}

Expand Down Expand Up @@ -72,7 +111,7 @@ void calcAndStoreFMAwTailProcessing(OpBuilder &builder, Location loc,
Value outputVec = builder.create<LoadOp>(loc, vecType, output,
ValueRange{beginIdx, endIdx});
Value resVec =
builder.create<FMAOp>(loc, inputVec, kernelVec, outputVec);
insertFMAOp(builder, loc, vecType, inputVec, kernelVec, outputVec);
builder.create<StoreOp>(loc, resVec, output,
ValueRange{beginIdx, endIdx});

Expand All @@ -85,7 +124,7 @@ void calcAndStoreFMAwTailProcessing(OpBuilder &builder, Location loc,
loc, vecType, output, ValueRange{beginIdx, endIdx}, extraElemMask,
zeroPadding);
Value resVec =
builder.create<FMAOp>(loc, inputVec, kernelVec, outputVec);
insertFMAOp(builder, loc, vecType, inputVec, kernelVec, outputVec);
builder.create<MaskedStoreOp>(loc, output, ValueRange{beginIdx, endIdx},
extraElemMask, resVec);

Expand Down
30 changes: 25 additions & 5 deletions lib/Conversion/LowerDIP/LowerDIPPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"

#include "DIP/DIPDialect.h"
Expand Down Expand Up @@ -56,7 +60,7 @@ class DIPCorr2DOpLowering : public OpRewritePattern<dip::Corr2DOp> {
LogicalResult matchAndRewrite(dip::Corr2DOp op,
PatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto ctx = op->getContext();
auto *ctx = op->getContext();

// Create constant indices.
Value c0 = rewriter.create<ConstantIndexOp>(loc, 0);
Expand All @@ -72,7 +76,24 @@ class DIPCorr2DOpLowering : public OpRewritePattern<dip::Corr2DOp> {
auto boundaryOptionAttr = op.boundary_option();
Value strideVal = rewriter.create<ConstantIndexOp>(loc, stride);

FloatType f32 = FloatType::getF32(ctx);
auto inElemTy = input.getType().cast<MemRefType>().getElementType();
auto kElemTy = kernel.getType().cast<MemRefType>().getElementType();
auto outElemTy = output.getType().cast<MemRefType>().getElementType();
auto constElemTy = constantValue.getType();
if (inElemTy != kElemTy || kElemTy != outElemTy ||
outElemTy != constElemTy) {
return op->emitOpError() << "input, kernel, output and constant must "
"have the same element type";
}
// NB: we can infer element type for all operation to be the same as input
// since we verified that the operand types are the same
auto elemTy = inElemTy;
auto bitWidth = elemTy.getIntOrFloatBitWidth();
if (!elemTy.isF64() && !elemTy.isF32() && !elemTy.isInteger(bitWidth)) {
return op->emitOpError() << "supports only f32, f64 and integer types. "
<< elemTy << "is passed";
}

IntegerType i1 = IntegerType::get(ctx, 1);

// Create DimOp.
Expand All @@ -90,11 +111,10 @@ class DIPCorr2DOpLowering : public OpRewritePattern<dip::Corr2DOp> {
kernelSize};
SmallVector<int64_t, 8> steps{1, 1, stride, 1};

VectorType vectorTy32 = VectorType::get({stride}, f32);
VectorType vectorTy32 = VectorType::get({stride}, elemTy);
VectorType vectorMaskTy = VectorType::get({stride}, i1);

Value zeroPaddingElem =
rewriter.create<ConstantFloatOp>(loc, (APFloat)(float)0, f32);
Value zeroPaddingElem = insertZeroConstantOp(ctx, rewriter, loc, elemTy);
Value zeroPadding =
rewriter.create<BroadcastOp>(loc, vectorTy32, zeroPaddingElem);

Expand Down
42 changes: 42 additions & 0 deletions tests/Dialect/DIP/correlation2D_f32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//
// x86
//
// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \
// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=i32 \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

memref.global "private" @global_input : memref<3x3xf32> = dense<[[0. , 1. , 2. ],
[10., 11., 12.],
[20., 21., 22.]]>

memref.global "private" @global_identity : memref<3x3xf32> = dense<[[0., 0., 0.],
[0., 1., 0.],
[0., 0., 0.]]>

memref.global "private" @global_output : memref<3x3xf32> = dense<[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]>
func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }

func.func @main() -> i32 {
%input = memref.get_global @global_input : memref<3x3xf32>
%identity = memref.get_global @global_identity : memref<3x3xf32>
%output = memref.get_global @global_output : memref<3x3xf32>

%kernelAnchorX = arith.constant 1 : index
%kernelAnchorY = arith.constant 1 : index
%c = arith.constant 0. : f32
dip.corr_2d <CONSTANT_PADDING> %input, %identity, %output, %kernelAnchorX, %kernelAnchorY, %c : memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>, index, index, f32

%printed_output = memref.cast %output : memref<3x3xf32> to memref<*xf32>
call @printMemrefF32(%printed_output) : (memref<*xf32>) -> ()
// CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 2 offset = 0 sizes = \[3, 3\] strides = \[3, 1\] data =}}
// CHECK{LITERAL}: [[0, 1, 2],
// CHECK{LITERAL}: [10, 11, 12],
// CHECK{LITERAL}: [20, 21, 22]]

%ret = arith.constant 0 : i32
return %ret : i32
}
42 changes: 42 additions & 0 deletions tests/Dialect/DIP/correlation2D_f64.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//
// x86
//
// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \
// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=i32 \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

memref.global "private" @global_input : memref<3x3xf64> = dense<[[0. , 1. , 2. ],
[10., 11., 12.],
[20., 21., 22.]]>

memref.global "private" @global_identity : memref<3x3xf64> = dense<[[0., 0., 0.],
[0., 1., 0.],
[0., 0., 0.]]>

memref.global "private" @global_output : memref<3x3xf64> = dense<[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]>
func.func private @printMemrefF64(memref<*xf64>) attributes { llvm.emit_c_interface }

func.func @main() -> i32 {
%input = memref.get_global @global_input : memref<3x3xf64>
%identity = memref.get_global @global_identity : memref<3x3xf64>
%output = memref.get_global @global_output : memref<3x3xf64>

%kernelAnchorX = arith.constant 1 : index
%kernelAnchorY = arith.constant 1 : index
%c = arith.constant 0. : f64
dip.corr_2d <CONSTANT_PADDING> %input, %identity, %output, %kernelAnchorX, %kernelAnchorY, %c : memref<3x3xf64>, memref<3x3xf64>, memref<3x3xf64>, index, index, f64

%printed_output = memref.cast %output : memref<3x3xf64> to memref<*xf64>
call @printMemrefF64(%printed_output) : (memref<*xf64>) -> ()
// CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 2 offset = 0 sizes = \[3, 3\] strides = \[3, 1\] data =}}
// CHECK{LITERAL}: [[0, 1, 2],
// CHECK{LITERAL}: [10, 11, 12],
// CHECK{LITERAL}: [20, 21, 22]]

%ret = arith.constant 0 : i32
return %ret : i32
}
41 changes: 41 additions & 0 deletions tests/Dialect/DIP/correlation2D_i32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//
// x86
//
// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \
// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=i32 \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

memref.global "private" @global_input : memref<3x3xi32> = dense<[[0 , 1 , 2 ],
[10, 11, 12],
[20, 21, 22]]>

memref.global "private" @global_identity : memref<3x3xi32> = dense<[[0, 0, 0],
[0, 1, 0],
[0, 0, 0]]>

memref.global "private" @global_output : memref<3x3xi32> = dense<[[0, 0, 0],
[0, 0, 0],
[0, 0, 0]]>
func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface }

func.func @main() -> i32 {
%input = memref.get_global @global_input : memref<3x3xi32>
%identity = memref.get_global @global_identity : memref<3x3xi32>
%output = memref.get_global @global_output: memref<3x3xi32>

%kernelAnchorX = arith.constant 1 : index
%kernelAnchorY = arith.constant 1 : index
%c = arith.constant 0 : i32
dip.corr_2d <CONSTANT_PADDING> %input, %identity, %output, %kernelAnchorX, %kernelAnchorY, %c : memref<3x3xi32>, memref<3x3xi32>, memref<3x3xi32>, index, index, i32

%printed_output = memref.cast %output : memref<3x3xi32> to memref<*xi32>
call @printMemrefI32(%printed_output) : (memref<*xi32>) -> ()
// CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 2 offset = 0 sizes = \[3, 3\] strides = \[3, 1\] data =}}
// CHECK{LITERAL}: [[0, 1, 2],
// CHECK{LITERAL}: [10, 11, 12],
// CHECK{LITERAL}: [20, 21, 22]]
%ret = arith.constant 0 : i32
return %ret : i32
}
42 changes: 42 additions & 0 deletions tests/Dialect/DIP/correlation2D_i64.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//
// x86
//
// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \
// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=i32 \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

memref.global "private" @global_input : memref<3x3xi64> = dense<[[0 , 1 , 2 ],
[10, 11, 12],
[20, 21, 22]]>

memref.global "private" @global_identity : memref<3x3xi64> = dense<[[0, 0, 0],
[0, 1, 0],
[0, 0, 0]]>

memref.global "private" @global_output : memref<3x3xi64> = dense<[[0, 0, 0],
[0, 0, 0],
[0, 0, 0]]>

func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface }

func.func @main() -> i32 {
%input = memref.get_global @global_input : memref<3x3xi64>
%identity = memref.get_global @global_identity : memref<3x3xi64>
%output = memref.get_global @global_output: memref<3x3xi64>

%kernelAnchorX = arith.constant 1 : index
%kernelAnchorY = arith.constant 1 : index
%c = arith.constant 0 : i64
dip.corr_2d <CONSTANT_PADDING> %input, %identity, %output, %kernelAnchorX, %kernelAnchorY, %c : memref<3x3xi64>, memref<3x3xi64>, memref<3x3xi64>, index, index, i64

%printed_output = memref.cast %output : memref<3x3xi64> to memref<*xi64>
call @printMemrefI64(%printed_output) : (memref<*xi64>) -> ()
// CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 2 offset = 0 sizes = \[3, 3\] strides = \[3, 1\] data =}}
// CHECK{LITERAL}: [[0, 1, 2],
// CHECK{LITERAL}: [10, 11, 12],
// CHECK{LITERAL}: [20, 21, 22]]
%ret = arith.constant 0 : i32
return %ret : i32
}
Loading

0 comments on commit 39b1b78

Please sign in to comment.