[mlir][linalg] Implement Conv2D using Winograd Conv2D algorithm (#96181)
Define high-level Winograd operators and convert conv_2d_nhwc_fhwc into
Winograd operators. Following the Winograd Conv2D algorithm, we need
three transform operators, for input, filter, and output transformation.

The formula of the Winograd Conv2D algorithm is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

filter transform: G x g x G^T
input transform: B^T x d x B
output transform: A^T x y x A

The implementation is based on the paper Fast Algorithms for
Convolutional Neural Networks (https://arxiv.org/abs/1509.09308).

Reviewers: stellaraccident, ftynse, Max191, GeorgeARM, cxy-1993, nicolasvasilache, MaheshRavishankar, dcaballe, rengolin

Reviewed By: ftynse, Max191, stellaraccident

Pull Request: #96181
Hsiangkai authored Jul 10, 2024
1 parent 015526b commit 7d246e8
Showing 9 changed files with 943 additions and 0 deletions.
117 changes: 117 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,4 +154,121 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
let hasVerifier = 1;
}

def Linalg_WinogradFilterTransformOp :
Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> {
let summary = "Winograd filter transform operator";
let description = [{
The Winograd Conv2D algorithm converts a linalg Conv2D operator into batched
matrix multiplies. Before the matrix multiply, it transforms the filter and
the input into a format suitable for batched matrix multiply. After the
matrix multiply, it converts the intermediate result into the final output
tensor.

The algorithm F(m x m, r x r) is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

The size of output Y is m x m. The size of filter g is r x r. The size of
input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
transformation matrices.

This operator is defined to represent the high level concept of filter
transformation (G x g x G^T) in the Winograd Conv2D algorithm.
}];

let arguments = (ins TensorRankOf<[AnyType], [4]>:$filter,
TensorRankOf<[AnyType], [4]>:$output,
I64Attr:$m,
I64Attr:$r
);

let results = (outs TensorRankOf<[AnyType], [4]>:$result);
let assemblyFormat = [{
attr-dict
`m` `(` $m `)`
`r` `(` $r `)`
`ins` `(` $filter `:` type($filter) `)`
`outs` `(` $output `:` type($output) `)`
`->` type($result)
}];
let hasVerifier = 1;
}
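For illustration (not part of the commit), a hypothetical use of the op with m = 4 and r = 3: per the verifier's shape rules, a 3x3 filter in FHWC layout with 2 filters and 5 channels produces (m + r - 1) x (m + r - 1) = 6x6 transformed tiles, with the channel and filter dimensions transposed in the result. `%filter` and `%init` are assumed SSA values:

```mlir
%1 = linalg.winograd_filter_transform m(4) r(3)
       ins(%filter : tensor<2x3x3x5xf32>)
       outs(%init : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
```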

def Linalg_WinogradInputTransformOp :
Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> {
let summary = "Winograd input transform operator";
let description = [{
The Winograd Conv2D algorithm converts a linalg Conv2D operator into batched
matrix multiplies. Before the matrix multiply, it transforms the filter and
the input into a format suitable for batched matrix multiply. After the
matrix multiply, it converts the intermediate result into the final output
tensor.

The algorithm F(m x m, r x r) is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

The size of output Y is m x m. The size of filter g is r x r. The size of
input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
transformation matrices.

This operator is defined to represent the high level concept of input
transformation (B^T x d x B) in the Winograd Conv2D algorithm.
}];

let arguments = (ins TensorRankOf<[AnyType], [4]>:$input,
TensorRankOf<[AnyType], [6]>:$output,
I64Attr:$m,
I64Attr:$r
);

let results = (outs TensorRankOf<[AnyType], [6]>:$result);
let assemblyFormat = [{
attr-dict
`m` `(` $m `)`
`r` `(` $r `)`
`ins` `(` $input `:` type($input) `)`
`outs` `(` $output `:` type($output) `)`
`->` type($result)
}];
let hasVerifier = 1;
}
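Again for illustration (hypothetical values, not from the commit): with m = 4 and r = 3, an NHWC input of 2x10x10x5 yields 6x6 tiles and (10 - (r - 1)) / m = 2 tiles along each spatial dimension, so the 6-D result is (tileH, tileW, numTilesH, numTilesW, N, C):

```mlir
%2 = linalg.winograd_input_transform m(4) r(3)
       ins(%input : tensor<2x10x10x5xf32>)
       outs(%init : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
```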

def Linalg_WinogradOutputTransformOp :
Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> {
let summary = "Winograd output transform operator";
let description = [{
The Winograd Conv2D algorithm converts a linalg Conv2D operator into batched
matrix multiplies. Before the matrix multiply, it transforms the filter and
the input into a format suitable for batched matrix multiply. After the
matrix multiply, it converts the intermediate result into the final output
tensor.

The algorithm F(m x m, r x r) is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

The size of output Y is m x m. The size of filter g is r x r. The size of
input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
transformation matrices.

This operator is defined to represent the high level concept of output
transformation (A^T x y x A) in the Winograd Conv2D algorithm.
}];

let arguments = (ins TensorRankOf<[AnyType], [6]>:$value,
TensorRankOf<[AnyType], [4]>:$output,
I64Attr:$m,
I64Attr:$r
);

let results = (outs TensorRankOf<[AnyType], [4]>:$result);
let assemblyFormat = [{
attr-dict
`m` `(` $m `)`
`r` `(` $r `)`
`ins` `(` $value `:` type($value) `)`
`outs` `(` $output `:` type($output) `)`
`->` type($result)
}];
let hasVerifier = 1;
}
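Completing the hypothetical m = 4, r = 3 example: a 6-D value of (tileH, tileW, numTilesH, numTilesW, N, F) = 6x6x2x2x2x2 folds back into an NHWF result with spatial size m * numTiles = 8:

```mlir
%3 = linalg.winograd_output_transform m(4) r(3)
       ins(%value : tensor<6x6x2x2x2x2xf32>)
       outs(%init : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
```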

#endif // LINALG_OPS
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1735,6 +1735,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
const ControlBlockPackMatmulFn &controlFn);

/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
int64_t r);

/// Adds patterns that reduce the rank of named contraction ops that have
/// unit dimensions in the operand(s) by converting to a sequence of `collapse_shape`,
/// `<corresponding linalg named op>`, `expand_shape` (if on tensors). For example a
116 changes: 116 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2739,6 +2739,122 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
return SmallVector<Value>{result};
}

//===----------------------------------------------------------------------===//
// WinogradFilterTransformOp
//===----------------------------------------------------------------------===//

LogicalResult WinogradFilterTransformOp::verify() {
auto filterType = cast<ShapedType>(getFilter().getType());
ArrayRef<int64_t> filterShape = filterType.getShape();
int64_t filterH = filterShape[1];
int64_t filterW = filterShape[2];
int64_t r = getR();
int64_t m = getM();

if (filterH != r && filterH != 1)
return emitOpError("expect filter height to be either r or 1");
if (filterW != r && filterW != 1)
return emitOpError("expect filter width to be either r or 1");
if (filterH == 1 && filterW == 1)
return emitOpError("expect either filter height or width to be r");

SmallVector<int64_t> expectedOutputShape;
expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
expectedOutputShape.push_back(filterShape[3]);
expectedOutputShape.push_back(filterShape[0]);

auto outputType = cast<ShapedType>(getOutput().getType());
ArrayRef<int64_t> outputShape = outputType.getShape();
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return emitOpError("the output shape is not as expected");
}
return success();
}

//===----------------------------------------------------------------------===//
// WinogradInputTransformOp
//===----------------------------------------------------------------------===//

LogicalResult WinogradInputTransformOp::verify() {
auto inputType = cast<ShapedType>(getInput().getType());
ArrayRef<int64_t> inputShape = inputType.getShape();
int64_t inputH = inputShape[1];
int64_t inputW = inputShape[2];
int64_t m = getM();
int64_t r = getR();
int64_t tileSize = m + r - 1;
bool leftTransform = inputH != 1;
bool rightTransform = inputW != 1;

SmallVector<int64_t> expectedOutputShape(6, inputH);
if (ShapedType::isDynamic(inputH)) {
expectedOutputShape[0] = tileSize;
expectedOutputShape[2] = ShapedType::kDynamic;
} else {
expectedOutputShape[0] = leftTransform ? tileSize : 1;
expectedOutputShape[2] = leftTransform ? (inputH - (r - 1)) / m : 1;
}
if (ShapedType::isDynamic(inputW)) {
expectedOutputShape[1] = tileSize;
expectedOutputShape[3] = ShapedType::kDynamic;
} else {
expectedOutputShape[1] = rightTransform ? tileSize : 1;
expectedOutputShape[3] = rightTransform ? (inputW - (r - 1)) / m : 1;
}
expectedOutputShape[4] = inputShape[0];
expectedOutputShape[5] = inputShape[3];

auto outputType = cast<ShapedType>(getOutput().getType());
ArrayRef<int64_t> outputShape = outputType.getShape();
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return emitOpError("the output shape is not as expected");
}
return success();
}

//===----------------------------------------------------------------------===//
// WinogradOutputTransformOp
//===----------------------------------------------------------------------===//

LogicalResult WinogradOutputTransformOp::verify() {
auto valueType = cast<ShapedType>(getValue().getType());
ArrayRef<int64_t> valueShape = valueType.getShape();
int64_t valueH = valueShape[0];
int64_t valueW = valueShape[1];
int64_t valueTileH = valueShape[2];
int64_t valueTileW = valueShape[3];
int64_t m = getM();
int64_t r = getR();
bool leftTransform = valueH != 1;
bool rightTransform = valueW != 1;

SmallVector<int64_t> expectedOutputShape(4, valueH);
if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
expectedOutputShape[1] = ShapedType::kDynamic;
} else {
if (valueH != (leftTransform ? m + r - 1 : 1))
return emitOpError("expect input height to equal the input tile size");
expectedOutputShape[1] = (leftTransform ? m : 1) * valueTileH;
}
if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
expectedOutputShape[2] = ShapedType::kDynamic;
} else {
if (valueW != (rightTransform ? m + r - 1 : 1))
return emitOpError("expect input width to equal the input tile size");
expectedOutputShape[2] = (rightTransform ? m : 1) * valueTileW;
}
expectedOutputShape[0] = valueShape[4];
expectedOutputShape[3] = valueShape[5];

auto outputType = cast<ShapedType>(getOutput().getType());
ArrayRef<int64_t> outputShape = outputType.getShape();
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return emitOpError("the output shape is not as expected");
}
return success();
}

//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Transforms.cpp
TransposeConv2D.cpp
Vectorization.cpp
WinogradConv2D.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
