Skip to content

Commit

Permalink
Merge pull request #388 from Xilinx/tiagot.constant_folding_tosa_slice
Browse files Browse the repository at this point in the history
feat: implement constant folding for tosa.slice
  • Loading branch information
ttjost authored Oct 15, 2024
2 parents b04eab8 + 83bdfaf commit 08bb427
Show file tree
Hide file tree
Showing 3 changed files with 313 additions and 0 deletions.
128 changes: 128 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1688,6 +1688,133 @@ struct TosaFoldConstantPad : public TosaFoldConstantBase<tosa::PadOp> {
}
};

template <typename BaseType, typename RangeT>
void sliceArray(ShapedType inputType, RangeT inputValues,
llvm::ArrayRef<int64_t> startValues, ShapedType outputType,
SmallVector<BaseType> &outputValues) {

auto outputShape = outputType.getShape();
auto inputShape = inputType.getShape();

int64_t rank = inputType.getRank();

// Implements the logic from
// https://www.mlplatform.org/tosa/tosa_spec.html#_slice
for (size_t outIndex = 0, e = outputValues.size(); outIndex < e; ++outIndex) {
auto indexInTarget = offsetToIndex(outputShape, outIndex);

for (int64_t i = 0; i < rank; ++i) {
indexInTarget[i] = indexInTarget[i] + startValues[i];
}

auto inputIndexOffset = indexToOffset(inputShape, indexInTarget);
outputValues[outIndex] = inputValues[inputIndexOffset];
}
}

template <typename BaseType>
DenseElementsAttr sliceType(ElementsAttr attr, ShapedType inputType,
llvm::ArrayRef<int64_t> start,
ShapedType outputType) {

auto inputValues = attr.getValues<BaseType>();
SmallVector<BaseType> outputValues(outputType.getNumElements(),
*std::begin(inputValues));
sliceArray<BaseType>(inputType, inputValues, start, outputType, outputValues);
return DenseElementsAttr::get(outputType,
llvm::ArrayRef<BaseType>(outputValues));
}

template <typename BaseType>
DenseElementsAttr sliceTypeRaw(ElementsAttr attr, ShapedType inputType,
llvm::ArrayRef<int64_t> start,
ShapedType outputType) {

ArrayRef<BaseType> inputValues =
cast<DenseIntOrFPElementsAttr>(attr).getNonSplatRawData<BaseType>();

SmallVector<BaseType> outputValues;
outputValues.resize_for_overwrite(outputType.getNumElements());
sliceArray<BaseType>(inputType, inputValues, start, outputType, outputValues);

ArrayRef rawOutputValues(reinterpret_cast<const char *>(outputValues.data()),
outputValues.size() * sizeof(BaseType));
return DenseElementsAttr::getFromRawBuffer(outputType, rawOutputValues);
}

DenseElementsAttr slice(ShapedType inputType, ElementsAttr inputValues,
llvm::ArrayRef<int64_t> start, ShapedType outputType) {

auto baseType = inputType.getElementType();

if (inputValues.isSplat()) {
if (isa<IntegerType>(baseType))
return DenseElementsAttr::get(outputType,
inputValues.getSplatValue<APInt>());
return DenseElementsAttr::get(outputType,
inputValues.getSplatValue<APFloat>());
}

// Handle possible integer types
if (auto intType = dyn_cast<IntegerType>(baseType)) {
switch (intType.getWidth()) {
case 1:
// i1 has special alignment which is not handled by sliceTypeRaw.
return sliceType<bool>(inputValues, inputType, start, outputType);
case 8:
return sliceTypeRaw<uint8_t>(inputValues, inputType, start, outputType);
case 16:
return sliceTypeRaw<uint16_t>(inputValues, inputType, start, outputType);
case 32:
return sliceTypeRaw<uint32_t>(inputValues, inputType, start, outputType);
case 64:
return sliceTypeRaw<uint64_t>(inputValues, inputType, start, outputType);
default:
return sliceType<APInt>(inputValues, inputType, start, outputType);
}
}

// Handle possible float types
if (baseType.isF32()) {
return sliceTypeRaw<uint32_t>(inputValues, inputType, start, outputType);
}
if (baseType.isF64()) {
return sliceTypeRaw<uint64_t>(inputValues, inputType, start, outputType);
}
if (baseType.isBF16()) {
return sliceTypeRaw<uint16_t>(inputValues, inputType, start, outputType);
}
return sliceType<APFloat>(inputValues, inputType, start, outputType);
}

struct TosaFoldConstantSlice : public TosaFoldConstantBase<tosa::SliceOp> {
using TosaFoldConstantBase::TosaFoldConstantBase;

LogicalResult matchAndRewrite(tosa::SliceOp op,
PatternRewriter &rewriter) const override {
auto outputType = cast<ShapedType>(op.getType());
// TOSA doesn't support quantized types.
if (!outputType.getElementType().isIntOrIndexOrFloat())
return failure();

auto start = op.getStart();
auto input = op.getInput();
ElementsAttr inputValues;
if (!matchPattern(input, m_Constant(&inputValues)))
return failure();

// Only fold op with multiple users if foldSplatOrSingleUseOnly is false.
if (!llvm::hasSingleElement(input.getDefiningOp()->getUsers()) &&
foldSplatOrSingleUseOnly)
return failure();

auto resultAttr = slice(input.getType(), inputValues, start, outputType);
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);

return success();
}
};

template <typename BaseType, typename RangeT>
void tileArray(ShapedType inputType, RangeT inputValues, ShapedType outputType,
SmallVector<BaseType> &outputValues) {
Expand Down Expand Up @@ -1991,6 +2118,7 @@ void mlir::tosa::populateTosaFoldConstantPatterns(
patterns.add<TosaFoldConstantMinimum>(ctx, options.foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantMaximum>(ctx, options.foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantPad>(ctx, options.foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantSlice>(ctx, options.foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantMatMul>(ctx, options.foldSplatOrSingleUseOnly);
if (options.enableTileFolding)
patterns.add<TosaFoldConstantTile>(ctx, options.foldSplatOrSingleUseOnly);
Expand Down
13 changes: 13 additions & 0 deletions mlir/test/Dialect/Tosa/constant-slice-multi-user.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold="fold-splat-or-single-use-only=0" %s | FileCheck %s
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold="fold-splat-or-single-use-only=1" %s | FileCheck %s --check-prefix=ONLY-SINGLE-USE-CHECK

// CHECK-LABEL: @slice_bf16
func.func @slice_bf16() -> (tensor<3x3xbf16>, tensor<3x2xbf16>) {
// CHECK-DAG: "tosa.const"() <{value = dense<{{\[\[}}3.000000e+00, 4.000000e+00, 5.000000e+00], [6.000000e+00, 7.000000e+00, 8.000000e+00], [9.000000e+00, 1.000000e+01, 1.100000e+01]]>
// CHECK-DAG: "tosa.const"() <{value = dense<{{\[\[}}4.000000e+00, 5.000000e+00], [7.000000e+00, 8.000000e+00], [1.000000e+01, 1.100000e+01]]>
// ONLY-SINGLE-USE-CHECK: tosa.slice
%0 = "tosa.const"() {value = dense<[[3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]> : tensor<3x3xbf16>} : () -> tensor<3x3xbf16>
%1 = "tosa.slice"(%0){size = array<i64: 3, 2>, start = array<i64: 0, 1>} : (tensor<3x3xbf16>) -> tensor<3x2xbf16>
return %0, %1 : tensor<3x3xbf16>, tensor<3x2xbf16>
}

172 changes: 172 additions & 0 deletions mlir/test/Dialect/Tosa/constant-slice.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s

// CHECK-LABEL: @slice_int8
func.func @slice_int8() -> (tensor<1x1xi8>) {
// CHECK: "tosa.const"() <{value = dense<3>
%0 = "tosa.const"() {value = dense<[[3, 4], [5, 6]]> : tensor<2x2xi8>} : () -> tensor<2x2xi8>
%1 = "tosa.slice"(%0){size = array<i64: 1, 1>, start = array<i64: 0, 0>} : (tensor<2x2xi8>) -> tensor<1x1xi8>
return %1 : tensor<1x1xi8>
}

func.func @slice_int16() -> (tensor<2x1xi16>) {
// CHECK: "tosa.const"() <{value = dense<{{\[\[}}3], [5]]>
%0 = "tosa.const"() {value = dense<[[3, 4], [5, 6]]> : tensor<2x2xi16>} : () -> tensor<2x2xi16>
%1 = "tosa.slice"(%0){size = array<i64: 2, 1>, start = array<i64: 0, 0>} : (tensor<2x2xi16>) -> tensor<2x1xi16>
return %1 : tensor<2x1xi16>
}

// CHECK-LABEL: @slice_int32
func.func @slice_int32() -> (tensor<2x1xi32>) {
// CHECK: "tosa.const"() <{value = dense<{{\[\[}}4], [6]]>
%0 = "tosa.const"() {value = dense<[[3, 4], [5, 6]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
%1 = "tosa.slice"(%0){size = array<i64: 2, 1>, start = array<i64: 0, 1>} : (tensor<2x2xi32>) -> tensor<2x1xi32>
return %1 : tensor<2x1xi32>
}

// CHECK-LABEL: @slice_int32_default_value
func.func @slice_int32_default_value() -> (tensor<3x1xi32>) {
// CHECK: "tosa.const"() <{value = dense<{{\[\[}}3], [6], [9]]>
%0 = "tosa.const"() {value = dense<[[3, 4, 5], [6, 7, 8], [9, 10, 11]]> : tensor<3x3xi32>} : () -> tensor<3x3xi32>
%1 = "tosa.slice"(%0){size = array<i64: 3, 1>, start = array<i64: 0, 0>} : (tensor<3x3xi32>) -> tensor<3x1xi32>
return %1 : tensor<3x1xi32>
}

// CHECK-LABEL: @slice_bf16_default_value
func.func @slice_bf16_default_value() -> (tensor<3x2xbf16>) {
// CHECK: "tosa.const"() <{value = dense<{{\[\[}}4.000000e+00, 5.000000e+00], [7.000000e+00, 8.000000e+00], [1.000000e+01, 1.100000e+01]]>
%0 = "tosa.const"() {value = dense<[[3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]> : tensor<3x3xbf16>} : () -> tensor<3x3xbf16>
%1 = "tosa.slice"(%0){size = array<i64: 3, 2>, start = array<i64: 0, 1>} : (tensor<3x3xbf16>) -> tensor<3x2xbf16>
return %1 : tensor<3x2xbf16>
}

// -----

// Following tests are all done with the following tensor, and different configurations:
// [[[1.0 , 2.25 , 3.50 , 4.75],
// [ 5.0 , 6.25 , 7.50 , 8.75]],
// [[ 13.0 , 14.25 , 15.50 , 16.75 ],
// [ 17.0 , 18.25 , 19.50 , 20.75]],
// [[-1.0 , -2.25 , -3.50 , -4.75],
// [ -5.0 , -6.25 , -7.50 , -8.75]],
// [[ -13.0 , -14.25 , -15.50 , -16.75 ],
// [ -17.0 , -18.25 , -19.50 , -20.75]]]

// Should produce
// 1.0, 2.25, 3.50, 4.75,
// 13.0, 14.25, 15.50, 16.75,
// -1.0, -2.25, -3.50, -4.75,
// -13.0, -14.25, -15.50, -16.75
func.func @slice_bf16_dim_1_start_zero() -> (tensor<4x1x4xbf16>) {
// CHECK-LABEL: @slice_bf16_dim_1_start_zero
// CHECK: 1.000000e+00, 2.250000e+00, 3.500000e+00, 4.750000e+00
// CHECK-SAME: 1.300000e+01, 1.425000e+01, 1.550000e+01, 1.675000e+01
// CHECK-SAME: -1.000000e+00, -2.250000e+00, -3.500000e+00, -4.750000e+00
// CHECK-SAME: -1.300000e+01, -1.425000e+01, -1.550000e+01, -1.675000e+01
%0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xbf16>} : () -> tensor<4x2x4xbf16>
%1 = "tosa.slice"(%0){size = array<i64: 4, 1, 4>, start = array<i64: 0, 0, 0>} : (tensor<4x2x4xbf16>) -> tensor<4x1x4xbf16>
return %1 : tensor<4x1x4xbf16>
}

// Should produce
// 1.0, 2.25, 3.50, 4.75,
// 13.0, 14.25, 15.50, 16.75,
// -1.0, -2.25, -3.50, -4.75,
// -13.0, -14.25, -15.50, -16.75
func.func @slice_f16_dim_1_start_zero() -> (tensor<4x1x4xf16>) {
// CHECK-LABEL: @slice_f16_dim_1_start_zero
// CHECK: 1.000000e+00, 2.250000e+00, 3.500000e+00, 4.750000e+00
// CHECK-SAME: 1.300000e+01, 1.425000e+01, 1.550000e+01, 1.675000e+01
// CHECK-SAME: -1.000000e+00, -2.250000e+00, -3.500000e+00, -4.750000e+00
// CHECK-SAME: -1.300000e+01, -1.425000e+01, -1.550000e+01, -1.675000e+01
%0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xf16>} : () -> tensor<4x2x4xf16>
%1 = "tosa.slice"(%0){size = array<i64: 4, 1, 4>, start = array<i64: 0, 0, 0>} : (tensor<4x2x4xf16>) -> tensor<4x1x4xf16>
return %1 : tensor<4x1x4xf16>
}

// Should produce
// 5.0, 6.25, 7.50, 8.75
// 17.0, 18.25, 19.50, 20.75
// -5.0, -6.25, -7.50, -8.75
// -17.0, -18.25, -19.50, -20.75
func.func @slice_bf16_start_dim_1_start_one() -> (tensor<4x1x4xbf16>) {
// CHECK-LABEL: @slice_bf16_start_dim_1_start_one
// CHECK: 5.000000e+00, 6.250000e+00, 7.500000e+00, 8.750000e+00
// CHECK-SAME: 1.700000e+01, 1.825000e+01, 1.950000e+01, 2.075000e+01
// CHECK-SAME: -5.000000e+00, -6.250000e+00, -7.500000e+00, -8.750000e+00
// CHECK-SAME: -1.700000e+01, -1.825000e+01, -1.950000e+01, -2.075000e+01
%0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xbf16>} : () -> tensor<4x2x4xbf16>
%1 = "tosa.slice"(%0){size = array<i64: 4, 1, 4>, start = array<i64: 0, 1, 0>} : (tensor<4x2x4xbf16>) -> tensor<4x1x4xbf16>
return %1 : tensor<4x1x4xbf16>
}

// Should produce
// 5.0, 6.25, 7.50, 8.75
// 17.0, 18.25, 19.50, 20.75
// -5.0, -6.25, -7.50, -8.75
// -17.0, -18.25, -19.50, -20.75
func.func @slice_f16_start_dim_1_start_one() -> (tensor<4x1x4xf16>) {
// CHECK-LABEL: @slice_f16_start_dim_1_start_one
// CHECK: 5.000000e+00, 6.250000e+00, 7.500000e+00, 8.750000e+00
// CHECK-SAME: 1.700000e+01, 1.825000e+01, 1.950000e+01, 2.075000e+01
// CHECK-SAME: -5.000000e+00, -6.250000e+00, -7.500000e+00, -8.750000e+00
// CHECK-SAME: -1.700000e+01, -1.825000e+01, -1.950000e+01, -2.075000e+01
%0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xf16>} : () -> tensor<4x2x4xf16>
%1 = "tosa.slice"(%0){size = array<i64: 4, 1, 4>, start = array<i64: 0, 1, 0>} : (tensor<4x2x4xf16>) -> tensor<4x1x4xf16>
return %1 : tensor<4x1x4xf16>
}

// Should produce
// 1.0, 2.25, 3.50
// 13.0, 14.25, 15.50
// -1.0, -2.25, -3.50
func.func @slice_bf16_start_zero_multiple_dims() -> (tensor<3x1x3xbf16>) {
// CHECK-LABEL: @slice_bf16_start_zero_multiple_dims
// CHECK: 1.000000e+00, 2.250000e+00, 3.500000e+00
// CHECK-SAME: 1.300000e+01, 1.425000e+01, 1.550000e+01
// CHECK-SAME: -1.000000e+00, -2.250000e+00, -3.500000e+00
%0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xbf16>} : () -> tensor<4x2x4xbf16>
%1 = "tosa.slice"(%0){size = array<i64: 3, 1, 3>, start = array<i64: 0, 0, 0>} : (tensor<4x2x4xbf16>) -> tensor<3x1x3xbf16>
return %1 : tensor<3x1x3xbf16>
}

// Should produce
// 1.0, 2.25, 3.50
// 13.0, 14.25, 15.50
// -1.0, -2.25, -3.50
func.func @slice_f16_start_zero_multiple_dims() -> (tensor<3x1x3xf16>) {
// CHECK-LABEL: @slice_f16_start_zero_multiple_dims
// CHECK: 1.000000e+00, 2.250000e+00, 3.500000e+00
// CHECK-SAME: 1.300000e+01, 1.425000e+01, 1.550000e+01
// CHECK-SAME: -1.000000e+00, -2.250000e+00, -3.500000e+00
%0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xf16>} : () -> tensor<4x2x4xf16>
%1 = "tosa.slice"(%0){size = array<i64: 3, 1, 3>, start = array<i64: 0, 0, 0>} : (tensor<4x2x4xf16>) -> tensor<3x1x3xf16>
return %1 : tensor<3x1x3xf16>
}

// Produces
// 18.25, 19.50, 20.75
// -6.25, -7.50, -8.75
// -18.25, -19.50, -20.75
func.func @slice_bf16_start_non_zero_multiple_dims() -> (tensor<3x1x3xbf16>) {
// CHECK-LABEL: @slice_bf16_start_non_zero_multiple_dims
// CHECK: 1.825000e+01, 1.950000e+01, 2.075000e+01
// CHECK-SAME: -6.250000e+00, -7.500000e+00, -8.750000e+00
// CHECK-SAME: -1.825000e+01, -1.950000e+01, -2.075000e+01
%0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xbf16>} : () -> tensor<4x2x4xbf16>
%1 = "tosa.slice"(%0){size = array<i64: 3, 1, 3>, start = array<i64: 1, 1, 1>} : (tensor<4x2x4xbf16>) -> tensor<3x1x3xbf16>
return %1 : tensor<3x1x3xbf16>
}

// Produces
// 18.25, 19.50, 20.75
// -6.25, -7.50, -8.75
// -18.25, -19.50, -20.75
func.func @slice_f16_start_non_zero_multiple_dims() -> (tensor<3x1x3xf16>) {
// CHECK-LABEL: @slice_f16_start_non_zero_multiple_dims
// CHECK: 1.825000e+01, 1.950000e+01, 2.075000e+01
// CHECK-SAME: -6.250000e+00, -7.500000e+00, -8.750000e+00
// CHECK-SAME: -1.825000e+01, -1.950000e+01, -2.075000e+01
%0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xf16>} : () -> tensor<4x2x4xf16>
%1 = "tosa.slice"(%0){size = array<i64: 3, 1, 3>, start = array<i64: 1, 1, 1>} : (tensor<4x2x4xf16>) -> tensor<3x1x3xf16>
return %1 : tensor<3x1x3xf16>
}

0 comments on commit 08bb427

Please sign in to comment.