Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer operation data type based on its params for Corr2D #63

Merged
merged 16 commits into from
Aug 8, 2022
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/DIPDialect/dip.mlir
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
// Example: dip.corr_2d with the CONSTANT_PADDING boundary option, where the
// region outside the input image is filled with %constantValue.
func.func @corr_2d_constant_padding(%inputImage : memref<?x?xf32>, %kernel : memref<?x?xf32>, %outputImage : memref<?x?xf32>, %centerX : index, %centerY : index, %constantValue : f32)
{
dip.corr_2d CONSTANT_PADDING %inputImage, %kernel, %outputImage, %centerX, %centerY, %constantValue : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
dip.corr_2d <CONSTANT_PADDING> %inputImage, %kernel, %outputImage, %centerX, %centerY, %constantValue : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
return
}

// Example: dip.corr_2d with the REPLICATE_PADDING boundary option, where the
// first/last element of each row/column is repeated into the extra region.
// %constantValue is still passed because the op's assembly format requires it.
func.func @corr_2d_replicate_padding(%inputImage : memref<?x?xf32>, %kernel : memref<?x?xf32>, %outputImage : memref<?x?xf32>, %centerX : index, %centerY : index, %constantValue : f32)
{
dip.corr_2d REPLICATE_PADDING %inputImage, %kernel, %outputImage, %centerX, %centerY , %constantValue : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
dip.corr_2d <REPLICATE_PADDING> %inputImage, %kernel, %outputImage, %centerX, %centerY , %constantValue : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
return
}

Expand Down
1 change: 1 addition & 0 deletions include/Dialect/DIP/DIPDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def DIP_Dialect : Dialect {
of developing a MLIR backend for performing image processing operations
such as 2D Correlation, Morphological processing, etc.
}];
let useDefaultAttributePrinterParser = 1;
let cppNamespace = "::buddy::dip";
}

Expand Down
49 changes: 26 additions & 23 deletions include/Dialect/DIP/DIPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,43 +55,46 @@ def DIP_InterpolationType : I32EnumAttr<"InterpolationType",
let cppNamespace = "::buddy::dip";
}

def DIP_BoundaryOptionAttr : EnumAttr<DIP_Dialect, DIP_BoundaryOption, "boundary_option">;
def DIP_BoundaryOptionAttr : EnumAttr<DIP_Dialect, DIP_BoundaryOption, "boundary_option"> {
let assemblyFormat = "`<` $value `>`";
}
def DIP_InterpolationAttr : EnumAttr<DIP_Dialect, DIP_InterpolationType, "interpolation_type">;

def DIP_Corr2DOp : DIP_Op<"corr_2d">
{
def DIP_Corr2DOp : DIP_Op<"corr_2d"> {
let summary = [{This operation is used for performing 2D correlation on an image.
The 2D correlation API provided by the linalg dialect is more suited for
applications in which boundary extrapolation is not explicitly required.
Due to this, dimensions of output are always less than the input dimensions after
using linalg dialect's 2D correlation API.

dip.corr_2d performs boundary extrapolation for making the size of the output image
equal to the size of the input image. Boundary extrapolation can be done using
different methods, supported options are :
a. Constant Padding : Uses a constant for padding whole extra region in input image
for obtaining the boundary extrapolated output image. (kkk|abcdefg|kkk)
b. Replicate Padding : Uses last/first element of respective column/row for padding
the extra region used for creating the boundary extrapolated output image. (aaa|abcdefg|ggg)
For example:
The 2D correlation API provided by the linalg dialect is more suited for
applications in which boundary extrapolation is not explicitly required.
Due to this, dimensions of output are always less than the input dimensions after
using linalg dialect's 2D correlation API.

dip.corr_2d performs boundary extrapolation for making the size of the output image
equal to the size of the input image. Boundary extrapolation can be done using
different methods, supported options are:
a. Constant Padding : Uses a constant for padding whole extra region in input image
for obtaining the boundary extrapolated output image. (kkk|abcdefg|kkk)
b. Replicate Padding : Uses last/first element of respective column/row for padding
the extra region used for creating the boundary extrapolated output image. (aaa|abcdefg|ggg)
For example:

```mlir
dip.corr_2d CONSTANT_PADDING %inputImage, %kernel, %output, %centerX, %centerY, %constantValue
: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, index
```
```mlir
dip.corr_2d <CONSTANT_PADDING> %inputImage, %kernel, %output, %centerX, %centerY, %constantValue
: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, index, index, f32
```
}];

let arguments = (ins Arg<AnyRankedOrUnrankedMemRef, "inputMemref",
[MemRead]>:$memrefI,
Arg<AnyRankedOrUnrankedMemRef, "kernelMemref",
[MemRead]>:$memrefK,
Arg<AnyRankedOrUnrankedMemRef, "outputMemref",
[MemRead]>:$memrefCO,
Index : $centerX, Index : $centerY, F32 : $constantValue,
[MemWrite]>:$memrefO,
Index : $centerX,
Index : $centerY,
AnyTypeOf<[AnyI32, AnyFloat]> : $constantValue,
DIP_BoundaryOptionAttr:$boundary_option);

let assemblyFormat = [{
$boundary_option $memrefI `,` $memrefK `,` $memrefCO `,` $centerX `,` $centerY `,` $constantValue attr-dict `:` type($memrefI) `,` type($memrefK) `,` type($memrefCO) `,` type($centerX) `,` type($centerY) `,` type($constantValue)
$boundary_option $memrefI `,` $memrefK `,` $memrefO `,` $centerX `,` $centerY `,` $constantValue attr-dict `:` type($memrefI) `,` type($memrefK) `,` type($memrefO) `,` type($centerX) `,` type($centerY) `,` type($constantValue)
}];
}

Expand Down
31 changes: 26 additions & 5 deletions include/Utils/DIPUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,14 @@ void calcAndStoreFMAwoTailProcessing(OpBuilder &builder, Location loc,
Value beginIdx, Value endIdx) {
Value outputVec = builder.create<LoadOp>(loc, vecType, output,
ValueRange{beginIdx, endIdx});
Value resVec = builder.create<FMAOp>(loc, inputVec, kernelVec, outputVec);
Value resVec = {};
auto elemTy = vecType.getElementType();
if (elemTy.isF32()) {
resVec = builder.create<vector::FMAOp>(loc, inputVec, kernelVec, outputVec);
} else if (elemTy.isInteger(32)) {
Value mulVec = builder.create<arith::MulIOp>(loc, inputVec, kernelVec);
resVec = builder.create<arith::AddIOp>(loc, mulVec, outputVec);
}
builder.create<StoreOp>(loc, resVec, output, ValueRange{beginIdx, endIdx});
}

Expand Down Expand Up @@ -71,8 +78,15 @@ void calcAndStoreFMAwTailProcessing(OpBuilder &builder, Location loc,
[&](OpBuilder &builder, Location loc) {
Value outputVec = builder.create<LoadOp>(loc, vecType, output,
ValueRange{beginIdx, endIdx});
Value resVec =
builder.create<FMAOp>(loc, inputVec, kernelVec, outputVec);
Value resVec = {};
auto elemTy = vecType.getElementType();
if (elemTy.isF32()) {
resVec = builder.create<vector::FMAOp>(loc, inputVec, kernelVec, outputVec);
} else if (elemTy.isInteger(32)) {
Value mulVec = builder.create<arith::MulIOp>(loc, inputVec, kernelVec);
resVec = builder.create<arith::AddIOp>(loc, mulVec, outputVec);
}

builder.create<StoreOp>(loc, resVec, output,
ValueRange{beginIdx, endIdx});

Expand All @@ -84,8 +98,15 @@ void calcAndStoreFMAwTailProcessing(OpBuilder &builder, Location loc,
Value outputVec = builder.create<MaskedLoadOp>(
loc, vecType, output, ValueRange{beginIdx, endIdx}, extraElemMask,
zeroPadding);
Value resVec =
builder.create<FMAOp>(loc, inputVec, kernelVec, outputVec);
Value resVec = {};
auto elemTy = vecType.getElementType();
if (elemTy.isF32()) {
resVec = builder.create<vector::FMAOp>(loc, inputVec, kernelVec, outputVec);
} else if (elemTy.isInteger(32)) {
Value mulVec = builder.create<arith::MulIOp>(loc, inputVec, kernelVec);
resVec = builder.create<arith::AddIOp>(loc, mulVec, outputVec);
}
ArtemSkrebkov marked this conversation as resolved.
Show resolved Hide resolved

builder.create<MaskedStoreOp>(loc, output, ValueRange{beginIdx, endIdx},
extraElemMask, resVec);

Expand Down
30 changes: 24 additions & 6 deletions lib/Conversion/LowerDIP/LowerDIPPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"

#include "DIP/DIPDialect.h"
Expand Down Expand Up @@ -56,7 +57,7 @@ class DIPCorr2DOpLowering : public OpRewritePattern<dip::Corr2DOp> {
LogicalResult matchAndRewrite(dip::Corr2DOp op,
PatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto ctx = op->getContext();
auto *ctx = op->getContext();

// Create constant indices.
Value c0 = rewriter.create<ConstantIndexOp>(loc, 0);
Expand All @@ -72,7 +73,17 @@ class DIPCorr2DOpLowering : public OpRewritePattern<dip::Corr2DOp> {
auto boundaryOptionAttr = op.boundary_option();
Value strideVal = rewriter.create<ConstantIndexOp>(loc, stride);

FloatType f32 = FloatType::getF32(ctx);
auto inElemTy = input.getType().cast<MemRefType>().getElementType();
auto kElemTy = kernel.getType().cast<MemRefType>().getElementType();
auto outElemTy = output.getType().cast<MemRefType>().getElementType();
auto constElemTy = constantValue.getType();
if (inElemTy != kElemTy || kElemTy != outElemTy || outElemTy != constElemTy) {
return op->emitOpError() << "input, kernel, output and constant must have the same element type";
meshtag marked this conversation as resolved.
Show resolved Hide resolved
}
// NB: we can infer element type for all operation to be the same as input
// since we verified that the operand types are the same
auto elemTy = inElemTy;

IntegerType i1 = IntegerType::get(ctx, 1);

// Create DimOp.
Expand All @@ -90,11 +101,18 @@ class DIPCorr2DOpLowering : public OpRewritePattern<dip::Corr2DOp> {
kernelSize};
SmallVector<int64_t, 8> steps{1, 1, stride, 1};

VectorType vectorTy32 = VectorType::get({stride}, f32);
VectorType vectorTy32 = VectorType::get({stride}, elemTy);
VectorType vectorMaskTy = VectorType::get({stride}, i1);

Value zeroPaddingElem =
rewriter.create<ConstantFloatOp>(loc, (APFloat)(float)0, f32);
FloatType f32 = FloatType::getF32(ctx);
IntegerType i32 = IntegerType::get(ctx, 32);
Value zeroPaddingElem = {};
// TODO: extend for other types and add a check for supported types
if (elemTy.isF32()) {
zeroPaddingElem = rewriter.create<ConstantFloatOp>(loc, (APFloat)(float)0, f32);
} else if (elemTy.isInteger(32)) {
zeroPaddingElem = rewriter.create<ConstantIntOp>(loc, 0, i32);
}
Value zeroPadding =
rewriter.create<BroadcastOp>(loc, vectorTy32, zeroPaddingElem);

Expand Down Expand Up @@ -490,7 +508,7 @@ class DIPCorr2DOpLowering : public OpRewritePattern<dip::Corr2DOp> {
builder.create<scf::YieldOp>(loc);
});
});
// Remove the origin convolution operation.

ArtemSkrebkov marked this conversation as resolved.
Show resolved Hide resolved
rewriter.eraseOp(op);
return success();
}
Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/DIP/DIPOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@

#include "DIP/DIPOps.h"
#include "DIP/DIPDialect.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Support/LogicalResult.h"

#define GET_OP_CLASSES
#include "DIP/DIPOps.cpp.inc"

ArtemSkrebkov marked this conversation as resolved.
Show resolved Hide resolved
ArtemSkrebkov marked this conversation as resolved.
Show resolved Hide resolved
43 changes: 43 additions & 0 deletions tests/Dialect/DIP/correlation2D.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//
// x86
//
// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \
// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=i32 \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

// 3x3 f32 test image; each value is row*10 + column, so a misplaced element
// is easy to spot in the printed output.
memref.global "private" @global_input : memref<3x3xf32> = dense<[[0. , 1. , 2. ],
[10., 11., 12.],
[20., 21., 22.]]>


// Identity kernel: a single 1 at the center, zeros elsewhere.
memref.global "private" @global_identity : memref<3x3xf32> = dense<[[0., 0., 0.],
[0., 1., 0.],
[0., 0., 0.]]>

// Output buffer, initialized to zeros.
memref.global "private" @global_output : memref<3x3xf32> = dense<[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]>
// Declaration of the external print routine for unranked f32 memrefs
// (presumably resolved from the libraries given via -shared-libs; verify).
func.func private @printMemrefF32(memref<*xf32>)

// Entry point. Correlates the 3x3 input with the identity kernel using
// constant padding and prints the result; with an identity kernel anchored at
// its center, the printed output is expected to equal the input.
func.func @main() -> i32 {
  %input = memref.get_global @global_input : memref<3x3xf32>
  %identity = memref.get_global @global_identity : memref<3x3xf32>
  %output = memref.get_global @global_output : memref<3x3xf32>

  // Kernel anchor: %x is passed as centerX and %y as centerY.
  %x = arith.constant 1 : index
  %y = arith.constant 1 : index
  // Value used to pad the region outside the image.
  %c = arith.constant 0. : f32
  dip.corr_2d <CONSTANT_PADDING> %input, %identity, %output, %x, %y, %c : memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>, index, index, f32

  // printMemrefF32 takes an unranked memref, so cast before the call.
  %printed_output = memref.cast %output : memref<3x3xf32> to memref<*xf32>
  call @printMemrefF32(%printed_output) : (memref<*xf32>) -> ()
  // CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 2 offset = 0 sizes = \[3, 3\] strides = \[3, 1\] data =}}
  // CHECK{LITERAL}: [[0, 1, 2],
  // CHECK{LITERAL}: [10, 11, 12],
  // CHECK{LITERAL}: [20, 21, 22]]

  %ret = arith.constant 0 : i32
  return %ret : i32
}
42 changes: 42 additions & 0 deletions tests/Dialect/DIP/correlation2D_i32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//
// x86
//
// RUN: buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \
// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=i32 \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

// 3x3 i32 test image; each value is row*10 + column, so a misplaced element
// is easy to spot in the printed output.
memref.global "private" @global_input : memref<3x3xi32> = dense<[[0 , 1 , 2 ],
[10, 11, 12],
[20, 21, 22]]>


// Identity kernel: a single 1 at the center, zeros elsewhere.
memref.global "private" @global_identity : memref<3x3xi32> = dense<[[0, 0, 0],
[0, 1, 0],
[0, 0, 0]]>

// Output buffer, initialized to zeros.
memref.global "private" @global_output : memref<3x3xi32> = dense<[[0, 0, 0],
[0, 0, 0],
[0, 0, 0]]>
// Declaration of the external print routine for unranked i32 memrefs
// (presumably resolved from the libraries given via -shared-libs; verify).
func.func private @printMemrefI32(memref<*xi32>)

// Entry point. Same scenario as the f32 correlation test but with i32
// element types: correlating with the identity kernel under constant padding
// is expected to reproduce the input in the printed output.
func.func @main() -> i32 {
  %input = memref.get_global @global_input : memref<3x3xi32>
  %identity = memref.get_global @global_identity : memref<3x3xi32>
  %output = memref.get_global @global_output : memref<3x3xi32>

  // Kernel anchor: %x is passed as centerX and %y as centerY.
  %x = arith.constant 1 : index
  %y = arith.constant 1 : index
  // Value used to pad the region outside the image.
  %c = arith.constant 0 : i32
  dip.corr_2d <CONSTANT_PADDING> %input, %identity, %output, %x, %y, %c : memref<3x3xi32>, memref<3x3xi32>, memref<3x3xi32>, index, index, i32

  // printMemrefI32 takes an unranked memref, so cast before the call.
  %printed_output = memref.cast %output : memref<3x3xi32> to memref<*xi32>
  call @printMemrefI32(%printed_output) : (memref<*xi32>) -> ()
  // CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 2 offset = 0 sizes = \[3, 3\] strides = \[3, 1\] data =}}
  // CHECK{LITERAL}: [[0, 1, 2],
  // CHECK{LITERAL}: [10, 11, 12],
  // CHECK{LITERAL}: [20, 21, 22]]
  %ret = arith.constant 0 : i32
  return %ret : i32
}
61 changes: 61 additions & 0 deletions tests/Dialect/DIP/correlation2D_invalid_type.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//
// x86
//
// RUN: not buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \
ArtemSkrebkov marked this conversation as resolved.
Show resolved Hide resolved
// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts 2>&1 | FileCheck %s

// f32 image / identity kernel / zeroed output used to build the mismatched
// operand combinations in @main.
memref.global "private" @global_input_f32 : memref<3x3xf32> = dense<[[0. , 1. , 2. ],
[10., 11., 12.],
[20., 21., 22.]]>


memref.global "private" @global_identity_f32 : memref<3x3xf32> = dense<[[0., 0., 0.],
[0., 1., 0.],
[0., 0., 0.]]>

memref.global "private" @global_output_f32 : memref<3x3xf32> = dense<[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]>

// i32 counterparts of the three buffers above.
memref.global "private" @global_input_i32 : memref<3x3xi32> = dense<[[0 , 1 , 2 ],
[10, 11, 12],
[20, 21, 22]]>


memref.global "private" @global_identity_i32 : memref<3x3xi32> = dense<[[0, 0, 0],
[0, 1, 0],
[0, 0, 0]]>

memref.global "private" @global_output_i32 : memref<3x3xi32> = dense<[[0, 0, 0],
[0, 0, 0],
[0, 0, 0]]>

// Negative test. Each dip.corr_2d below mismatches the element type of exactly
// one operand (input, kernel, output, constant value, in that order) against
// the other three, and is expected to be rejected with the same diagnostic.
func.func @main() -> i32 {
  %input_f32 = memref.get_global @global_input_f32 : memref<3x3xf32>
  %identity_f32 = memref.get_global @global_identity_f32 : memref<3x3xf32>
  %output_f32 = memref.get_global @global_output_f32 : memref<3x3xf32>
  %c_f32 = arith.constant 0. : f32

  %input_i32 = memref.get_global @global_input_i32 : memref<3x3xi32>
  %identity_i32 = memref.get_global @global_identity_i32 : memref<3x3xi32>
  %output_i32 = memref.get_global @global_output_i32 : memref<3x3xi32>
  %c_i32 = arith.constant 0 : i32

  // Kernel anchor: %x is passed as centerX and %y as centerY.
  %x = arith.constant 1 : index
  %y = arith.constant 1 : index

  // Mismatched input element type (i32 vs f32).
  dip.corr_2d <CONSTANT_PADDING> %input_i32, %identity_f32, %output_f32, %x, %y, %c_f32 : memref<3x3xi32>, memref<3x3xf32>, memref<3x3xf32>, index, index, f32
  // CHECK: 'dip.corr_2d' op input, kernel, output and constant must have the same element type

  // Mismatched kernel element type.
  dip.corr_2d <CONSTANT_PADDING> %input_f32, %identity_i32, %output_f32, %x, %y, %c_f32 : memref<3x3xf32>, memref<3x3xi32>, memref<3x3xf32>, index, index, f32
  // CHECK: 'dip.corr_2d' op input, kernel, output and constant must have the same element type

  // Mismatched output element type.
  dip.corr_2d <CONSTANT_PADDING> %input_f32, %identity_f32, %output_i32, %x, %y, %c_f32 : memref<3x3xf32>, memref<3x3xf32>, memref<3x3xi32>, index, index, f32
  // CHECK: 'dip.corr_2d' op input, kernel, output and constant must have the same element type

  // Mismatched constant value type.
  dip.corr_2d <CONSTANT_PADDING> %input_f32, %identity_f32, %output_f32, %x, %y, %c_i32 : memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>, index, index, i32
  // CHECK: 'dip.corr_2d' op input, kernel, output and constant must have the same element type

  %ret = arith.constant 0 : i32
  return %ret : i32
}
6 changes: 5 additions & 1 deletion tests/lit.cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,12 @@
'buddy-opt',
'buddy-translate',
'buddy-container-test',
'buddy-audio-container-test'
'buddy-audio-container-test',
'mlir-cpu-runner',
]
tools.extend([
ToolSubst('%mlir_runner_utils_dir', config.mlir_runner_utils_dir, unresolved='ignore'),
])

if config.buddy_enable_opencv == "ON":
tools.append('buddy-image-container-test')
Expand Down
1 change: 1 addition & 0 deletions tests/lit.site.cfg.py.in
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ config.host_arch = "@HOST_ARCH@"
config.buddy_src_root = "@CMAKE_SOURCE_DIR@"
config.buddy_obj_root = "@CMAKE_BINARY_DIR@"
config.buddy_enable_opencv = "@BUDDY_ENABLE_OPENCV@"
config.mlir_runner_utils_dir = "@LLVM_LIBS_DIR@"

# Support substitution of the tools_dir with user parameters. This is
# used when we can't determine the tool dir at configuration time.
Expand Down