Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FXML-1930] Implement mul folding #27

Merged
merged 3 commits into from
Apr 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ void populateTosaFoldConstantClampPatterns(MLIRContext *ctx,
void populateTosaFoldConstantCastPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
bool enableIntCastFolding);
/// Adds the rewrite pattern that folds `tosa.mul` operations on constant
/// operands to `patterns`.
void populateTosaFoldConstantMulPatterns(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaFoldConstantPowPatterns(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaFoldConstantAdd.cpp
TosaFoldConstantCast.cpp
TosaFoldConstantClamp.cpp
TosaFoldConstantMul.cpp
TosaFoldConstantPow.cpp
TosaFoldConstantReciprocal.cpp
TosaFoldConstantRSQRT.cpp
Expand Down
118 changes: 118 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantMul.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
//===- TosaFoldConstantMul.cpp --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA Mul operation on constant data
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include <llvm/ADT/APInt.h>
#include <mlir/Support/LogicalResult.h>

using namespace mlir;
using namespace mlir::tosa;

namespace {

/// Rewrite pattern that replaces a `tosa.mul` whose two inputs are both
/// constant tensors by a single `tosa.const` holding the element-wise product.
///
/// The fold is only attempted when
///  * the `shift` attribute of the multiplication is zero, and
///  * `constantBinaryOpShouldBeFolded` accepts the operation.
///
/// Integer inputs are sign-extended to the bit width of the result element
/// type before they are multiplied; a warning is emitted on the op if any
/// element overflows. Floating-point inputs are multiplied with `APFloat`.
struct TosaFoldConstantMul : public OpRewritePattern<MulOp> {

  using OpRewritePattern::OpRewritePattern;

  /// Tries to fold `mulOp` into a constant.
  ///
  /// Returns `success()` after `mulOp` has been replaced by a `tosa.const`,
  /// otherwise a match failure that explains why the fold was skipped.
  LogicalResult matchAndRewrite(MulOp mulOp,
                                PatternRewriter &rewriter) const override {
    // Any shift other than zero is rejected, exactly as the diagnostic says.
    if (mulOp.getShift() != 0) {
      return rewriter.notifyMatchFailure(
          mulOp, "Non-zero shift folding is currently not implemented.");
    }

    auto leftOp = mulOp.getInput1();
    auto rightOp = mulOp.getInput2();

    // Check if both tensors are constant
    auto lhsIsConstantCheck =
        notifyIfNoTosaDenseConstantTensor(leftOp, mulOp, rewriter);
    if (failed(lhsIsConstantCheck)) {
      return lhsIsConstantCheck;
    }
    auto rhsIsConstantCheck =
        notifyIfNoTosaDenseConstantTensor(rightOp, mulOp, rewriter);
    if (failed(rhsIsConstantCheck)) {
      return rhsIsConstantCheck;
    }

    // Extract the tensor values. The match results are checked defensively so
    // that a null attribute can never reach the folding code below.
    DenseElementsAttr lhsValues;
    DenseElementsAttr rhsValues;
    if (!matchPattern(leftOp, m_Constant(&lhsValues)) ||
        !matchPattern(rightOp, m_Constant(&rhsValues))) {
      return rewriter.notifyMatchFailure(
          mulOp, "Could not extract the constant values of both operands.");
    }

    if (!constantBinaryOpShouldBeFolded(mulOp, lhsValues, rhsValues)) {
      return rewriter.notifyMatchFailure(
          mulOp, "Currently, muls will only be folded if this requires only "
                 "little additional memory usage.");
    }

    DenseElementsAttr newTensor;

    auto lhsElemType = leftOp.getType().getElementType();
    auto rhsElemType = rightOp.getType().getElementType();
    assert(lhsElemType == rhsElemType);

    auto resultType = mulOp.getType();
    auto resultElementType = resultType.getElementType();
    if (isa<IntegerType>(lhsElemType)) {
      assert(isa<IntegerType>(rhsElemType) &&
             isa<IntegerType>(resultElementType));
      auto resultElementWidth = resultElementType.getIntOrFloatBitWidth();
      assert(resultElementWidth >= lhsElemType.getIntOrFloatBitWidth() &&
             "The multiplication is expected to have an at least as big output "
             "as input type");

      // Compute the multiplication and track if an overflow occurred to enable
      // emitting a warning
      bool mulOverflowed = false;
      auto intMulFun = [&resultElementWidth, &mulOverflowed](
                           const APInt &first, const APInt &second) {
        bool didOverflow;
        // Widen both factors to the result width first so that only products
        // that do not fit into the result type are reported as overflow.
        auto res = first.sext(resultElementWidth)
                       .smul_ov(second.sext(resultElementWidth), didOverflow);
        mulOverflowed |= didOverflow;
        return res;
      };
      newTensor = applyElementWise<APInt, APInt>(lhsValues, rhsValues,
                                                 resultType, intMulFun);
      if (mulOverflowed) {
        mulOp.emitWarning(
            "Multiplication did overflow. The results are unspecified.");
      }
    } else {
      assert(isa<FloatType>(lhsElemType) && isa<FloatType>(rhsElemType) &&
             isa<FloatType>(resultType.getElementType()));
      auto mulFun = [](const APFloat &first, const APFloat &second) {
        return first * second;
      };
      newTensor = applyElementWise<APFloat, APFloat>(lhsValues, rhsValues,
                                                     resultType, mulFun);
    }
    rewriter.replaceOpWithNewOp<ConstOp>(mulOp, newTensor.getType(), newTensor);

    return success();
  }
};

} // namespace

/// Registers the `TosaFoldConstantMul` rewrite pattern, constructed with
/// `ctx`, in `patterns`.
void mlir::tosa::populateTosaFoldConstantMulPatterns(
MLIRContext *ctx, RewritePatternSet &patterns) {
patterns.add<TosaFoldConstantMul>(ctx);
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ struct TosaLayerwiseConstantFoldPass
mlir::tosa::populateTosaFoldConstantCastPatterns(ctx, patterns,
enableIntCastFolding);
mlir::tosa::populateTosaFoldConstantClampPatterns(ctx, patterns);
mlir::tosa::populateTosaFoldConstantMulPatterns(ctx, patterns);
mlir::tosa::populateTosaFoldConstantPowPatterns(ctx, patterns);
mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
mlir::tosa::populateTosaFoldConstantRSQRTPatterns(ctx, patterns);
Expand Down
185 changes: 185 additions & 0 deletions mlir/test/Dialect/Tosa/constant-mul-opt.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
// RUN: mlir-opt --split-input-file -verify-diagnostics --tosa-layerwise-constant-fold %s | FileCheck %s

// Float multiplications

// CHECK-LABEL: @mul_fold_float
func.func @mul_fold_float() -> tensor<4xf16> {
  // Element-wise f16 product, including signed zeros. RES captures the SSA
  // name of the folded constant so that the return check is meaningful.
  // CHECK: [[RES:%.+]] ={{.*}}tosa.const{{.*}}2.32{{.*}}e+03, -1.49{{.*}}e+01, -0.{{0*}}e+00, -0.{{0*}}e+00
  // CHECK-NOT: tosa.mul
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value =
    dense<[-17.4978, 4.9882, 0.0, -0.0]> :
    tensor<4xf16>
  } : () -> tensor<4xf16>
  %1 = "tosa.const"() {value =
    dense<[-132.7, -3.0, -0.0, 5.0]> :
    tensor<4xf16>
  } : () -> tensor<4xf16>
  %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xf16>, tensor<4xf16>) -> tensor<4xf16>
  return %2 : tensor<4xf16>
}

// CHECK-LABEL: @mul_fold_float_infinity_nan
func.func @mul_fold_float_infinity_nan() -> tensor<7xf32> {
  // Products involving +/-infinity and NaN. RES captures the SSA name of the
  // folded constant so that the return check is meaningful.
  // CHECK: [[RES:%.+]] ={{.*}}tosa.const{{.*}}0x7F800000, 0x7F800000, 0xFF800000, 0xFF800000, 0x7FC00000, 0xFF800000, 0x7FC00000
  // CHECK-NOT: tosa.mul
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value =
    dense<[0x7F800000, 0xFF800000, 0x7F800000, 0xFF800000, 0x7FC00000, 0x7F800000, 0xFF800000]> :
    tensor<7xf32>
  } : () -> tensor<7xf32>
  %1 = "tosa.const"() {value =
    dense<[3.0, -3.0, -3.0, 3.0, 1.0, 0xFF800000, 0.0]> :
    tensor<7xf32>
  } : () -> tensor<7xf32>
  %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<7xf32>, tensor<7xf32>) -> tensor<7xf32>
  return %2 : tensor<7xf32>
}

// CHECK-LABEL: @mul_fold_float_overflow
func.func @mul_fold_float_overflow() -> tensor<2xf32> {
  // Products that exceed the f32 range are expected to fold to +/-infinity.
  // RES captures the SSA name of the folded constant.
  // CHECK: [[RES:%.+]] ={{.*}}tosa.const{{.*}}0x7F800000, 0xFF800000
  // CHECK-NOT: tosa.mul
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value =
    dense<[3.1e+38, -3.1e+38]> :
    tensor<2xf32>
  } : () -> tensor<2xf32>
  %1 = "tosa.const"() {value =
    dense<[2.1e+38, 1.1e+38]> :
    tensor<2xf32>
  } : () -> tensor<2xf32>
  %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
  return %2 : tensor<2xf32>
}

// -----
// Int multiplications

// CHECK-LABEL: @mul_fold_int
func.func @mul_fold_int() -> tensor<4xi32> {
  // Element-wise i32 product. RES captures the SSA name of the folded
  // constant so that the return check is meaningful.
  // CHECK: [[RES:%.+]] ={{.*}}tosa.const{{.*}}2244, -12, 0, 0
  // CHECK-NOT: tosa.mul
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value =
    dense<[-17, 4, 0, 0]> :
    tensor<4xi32>
  } : () -> tensor<4xi32>
  %1 = "tosa.const"() {value =
    dense<[-132, -3, 0, 5]> :
    tensor<4xi32>
  } : () -> tensor<4xi32>
  %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  return %2 : tensor<4xi32>
}

// CHECK-LABEL: @mul_fold_i8
func.func @mul_fold_i8() -> tensor<4xi32> {
  // i8 inputs with an i32 result: the product is computed in the wider result
  // type. RES captures the SSA name of the folded constant.
  // CHECK: [[RES:%.+]] ={{.*}}tosa.const{{.*}}204, -12, 0, 0
  // CHECK-NOT: tosa.mul
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value =
    dense<[-17, 4, -2, 0]> :
    tensor<4xi8>
  } : () -> tensor<4xi8>
  %1 = "tosa.const"() {value =
    dense<[-12, -3, 0, 5]> :
    tensor<4xi8>
  } : () -> tensor<4xi8>
  %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi32>
  return %2 : tensor<4xi32>
}

// CHECK-LABEL: @mul_fold_int_overflow
func.func @mul_fold_int_overflow() -> tensor<4xi32> {
  // Don't expect any specific results for the overflowing multiplication, just
  // that it is folded and that the overflow warning is emitted. RES captures
  // the SSA name of the folded constant.
  // CHECK: [[RES:%.+]] ={{.*}}tosa.const
  // CHECK-NOT: tosa.mul
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value =
    dense<[2147483647, 2147483640, -2147483648, -2147483640]> :
    tensor<4xi32>
  } : () -> tensor<4xi32>
  %1 = "tosa.const"() {value =
    dense<[1, 10, 1, 30]> :
    tensor<4xi32>
  } : () -> tensor<4xi32>
  // expected-warning@below {{Multiplication did overflow. The results are unspecified.}}
  %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  return %2 : tensor<4xi32>
}

// -----
// self-multiplication

// CHECK-LABEL: @mul_fold_equal_args
func.func @mul_fold_equal_args() -> tensor<3xi32> {
  // The same constant is used as both operands. RES captures the SSA name of
  // the folded constant so that the return check is meaningful.
  // CHECK: [[RES:%.+]] ={{.*}}tosa.const{{.*}}289, 16, 0
  // CHECK-NOT: tosa.mul
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value =
    dense<[-17, 4, 0]> :
    tensor<3xi32>
  } : () -> tensor<3xi32>
  %2 = "tosa.mul"(%0, %0) {shift = 0 : i32} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
  return %2 : tensor<3xi32>
}

// -----
// Broadcasted multiplications

// CHECK-LABEL: @mul_fold_int_broadcast_simple
func.func @mul_fold_int_broadcast_simple() -> tensor<3xi32> {
  // A single-element tensor is broadcast against a 3-element tensor. RES
  // captures the SSA name of the folded constant.
  // CHECK: [[RES:%.+]] ={{.*}}tosa.const{{.*}}204, -48, 0
  // CHECK-NOT: tosa.mul
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value =
    dense<[-17, 4, 0]> :
    tensor<3xi32>
  } : () -> tensor<3xi32>
  %1 = "tosa.const"() {value =
    dense<-12> :
    tensor<1xi32>
  } : () -> tensor<1xi32>
  %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32>
  return %2 : tensor<3xi32>
}

// CHECK-LABEL: @mul_fold_int_broadcast_complex
func.func @mul_fold_int_broadcast_complex() -> tensor<3x3xi32> {
  // A 3x1 tensor and a 1x3 tensor are broadcast against each other into a 3x3
  // result. RES captures the SSA name of the folded constant.
  // CHECK: [[RES:%.+]] ={{.*}}tosa.const
  // CHECK-SAME{LITERAL}: [[204, -119, -68],
  // CHECK-SAME{LITERAL}: [-12, 7, 4],
  // CHECK-SAME{LITERAL}: [-228, 133, 76]]
  // CHECK-NOT: tosa.mul
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value =
    dense<[[-17], [1], [19]]> :
    tensor<3x1xi32>
  } : () -> tensor<3x1xi32>
  %1 = "tosa.const"() {value =
    dense<[[-12, 7, 4]]> :
    tensor<1x3xi32>
  } : () -> tensor<1x3xi32>
  %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<3x1xi32>, tensor<1x3xi32>) -> tensor<3x3xi32>
  return %2 : tensor<3x3xi32>
}

// CHECK-LABEL: @mul_fold_int_non_zero_shift
func.func @mul_fold_int_non_zero_shift() -> tensor<4xi32> {
  // A multiplication with shift = 1 must not be folded: both constants and
  // the mul are expected to stay in place. FIRST, SECOND and MUL capture the
  // SSA names so that the operand and return checks are meaningful.
  // CHECK: [[FIRST:%.+]] ={{.*}}tosa.const
  // CHECK-NEXT: [[SECOND:%.+]] ={{.*}}tosa.const
  // CHECK-NEXT: [[MUL:%.+]] ={{.*}}tosa.mul{{.*}}[[FIRST]], [[SECOND]]
  // CHECK-NEXT: return [[MUL]]
  %0 = "tosa.const"() {value =
    dense<[-17, 4, 0, 0]> :
    tensor<4xi32>
  } : () -> tensor<4xi32>
  %1 = "tosa.const"() {value =
    dense<[-132, -3, 0, 5]> :
    tensor<4xi32>
  } : () -> tensor<4xi32>
  %2 = "tosa.mul"(%0, %1) {shift = 1 : i32} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  return %2 : tensor<4xi32>
}