Skip to content

Commit

Permalink
feat: implement constant folding for tosa.slice
Browse files Browse the repository at this point in the history
  • Loading branch information
ttjost committed Oct 14, 2024
1 parent 9d48ee6 commit 9489ae8
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 0 deletions.
128 changes: 128 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1688,6 +1688,133 @@ struct TosaFoldConstantPad : public TosaFoldConstantBase<tosa::PadOp> {
}
};

template <typename BaseType, typename RangeT>
void sliceArray(ShapedType inputType, RangeT inputValues,
llvm::ArrayRef<int64_t> startValues, ShapedType outputType,
SmallVector<BaseType> &outputValues) {

auto outputShape = outputType.getShape();
auto inputShape = inputType.getShape();

int64_t rank = inputType.getRank();

// Implements the logic from
// https://www.mlplatform.org/tosa/tosa_spec.html#_slice
for (size_t outIndex = 0, e = outputValues.size(); outIndex < e; ++outIndex) {
auto indexInTarget = offsetToIndex(outputShape, outIndex);

for (int64_t i = 0; i < rank; ++i) {
indexInTarget[i] = indexInTarget[i] + startValues[i];
}

auto inputIndexOffset = indexToOffset(inputShape, indexInTarget);
outputValues[outIndex] = inputValues[inputIndexOffset];
}
}

template <typename BaseType>
DenseElementsAttr sliceType(ElementsAttr attr, ShapedType inputType,
llvm::ArrayRef<int64_t> start,
ShapedType outputType) {

auto inputValues = attr.getValues<BaseType>();
SmallVector<BaseType> outputValues(outputType.getNumElements(),
*std::begin(inputValues));
sliceArray<BaseType>(inputType, inputValues, start, outputType, outputValues);
return DenseElementsAttr::get(outputType,
llvm::ArrayRef<BaseType>(outputValues));
}

template <typename BaseType>
DenseElementsAttr sliceTypeRaw(ElementsAttr attr, ShapedType inputType,
llvm::ArrayRef<int64_t> start,
ShapedType outputType) {

ArrayRef<BaseType> inputValues =
cast<DenseIntOrFPElementsAttr>(attr).getNonSplatRawData<BaseType>();

SmallVector<BaseType> outputValues;
outputValues.resize_for_overwrite(outputType.getNumElements());
sliceArray<BaseType>(inputType, inputValues, start, outputType, outputValues);

ArrayRef rawOutputValues(reinterpret_cast<const char *>(outputValues.data()),
outputValues.size() * sizeof(BaseType));
return DenseElementsAttr::getFromRawBuffer(outputType, rawOutputValues);
}

DenseElementsAttr slice(ShapedType inputType, ElementsAttr inputValues,
llvm::ArrayRef<int64_t> start, ShapedType outputType) {

auto baseType = inputType.getElementType();

if (inputValues.isSplat()) {
if (isa<IntegerType>(baseType))
return DenseElementsAttr::get(outputType,
inputValues.getSplatValue<APInt>());
return DenseElementsAttr::get(outputType,
inputValues.getSplatValue<APFloat>());
}

// Handle possible integer types
if (auto intType = dyn_cast<IntegerType>(baseType)) {
switch (intType.getWidth()) {
case 1:
// i1 has special alignment which is not handled by sliceTypeRaw.
return sliceType<bool>(inputValues, inputType, start, outputType);
case 8:
return sliceTypeRaw<uint8_t>(inputValues, inputType, start, outputType);
case 16:
return sliceTypeRaw<uint16_t>(inputValues, inputType, start, outputType);
case 32:
return sliceTypeRaw<uint32_t>(inputValues, inputType, start, outputType);
case 64:
return sliceTypeRaw<uint64_t>(inputValues, inputType, start, outputType);
default:
return sliceType<APInt>(inputValues, inputType, start, outputType);
}
}

// Handle possible float types
if (baseType.isF32()) {
return sliceTypeRaw<uint32_t>(inputValues, inputType, start, outputType);
}
if (baseType.isF64()) {
return sliceTypeRaw<uint64_t>(inputValues, inputType, start, outputType);
}
if (baseType.isBF16()) {
return sliceTypeRaw<uint16_t>(inputValues, inputType, start, outputType);
}
return sliceType<APFloat>(inputValues, inputType, start, outputType);
}

struct TosaFoldConstantSlice : public TosaFoldConstantBase<tosa::SliceOp> {
using TosaFoldConstantBase::TosaFoldConstantBase;

LogicalResult matchAndRewrite(tosa::SliceOp op,
PatternRewriter &rewriter) const override {
auto outputType = cast<ShapedType>(op.getType());
// TOSA doesn't support quantized types.
if (!outputType.getElementType().isIntOrIndexOrFloat())
return failure();

auto start = op.getStart();
auto input = op.getInput();
ElementsAttr inputValues;
if (!matchPattern(input, m_Constant(&inputValues)))
return failure();

// Only fold op with multiple users if foldSplatOrSingleUseOnly is false.
if (!llvm::hasSingleElement(input.getDefiningOp()->getUsers()) &&
foldSplatOrSingleUseOnly)
return failure();

auto resultAttr = slice(input.getType(), inputValues, start, outputType);
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);

return success();
}
};

template <typename BaseType, typename RangeT>
void tileArray(ShapedType inputType, RangeT inputValues, ShapedType outputType,
SmallVector<BaseType> &outputValues) {
Expand Down Expand Up @@ -1991,6 +2118,7 @@ void mlir::tosa::populateTosaFoldConstantPatterns(
patterns.add<TosaFoldConstantMinimum>(ctx, options.foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantMaximum>(ctx, options.foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantPad>(ctx, options.foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantSlice>(ctx, options.foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantMatMul>(ctx, options.foldSplatOrSingleUseOnly);
if (options.enableTileFolding)
patterns.add<TosaFoldConstantTile>(ctx, options.foldSplatOrSingleUseOnly);
Expand Down
40 changes: 40 additions & 0 deletions mlir/test/Dialect/Tosa/constant-slice.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s

// CHECK-LABEL: @slice_int8
func.func @slice_int8() -> (tensor<1x1xi8>) {
// CHECK: "tosa.const"() <{value = dense<3>
%0 = "tosa.const"() {value = dense<[[3, 4], [5, 6]]> : tensor<2x2xi8>} : () -> tensor<2x2xi8>
%1 = "tosa.slice"(%0){size = array<i64: 1, 1>, start = array<i64: 0, 0>} : (tensor<2x2xi8>) -> tensor<1x1xi8>
return %1 : tensor<1x1xi8>
}

func.func @slice_int16() -> (tensor<2x1xi16>) {
// CHECK: "tosa.const"() <{value = dense<{{\[\[}}3], [5]]>
%0 = "tosa.const"() {value = dense<[[3, 4], [5, 6]]> : tensor<2x2xi16>} : () -> tensor<2x2xi16>
%1 = "tosa.slice"(%0){size = array<i64: 2, 1>, start = array<i64: 0, 0>} : (tensor<2x2xi16>) -> tensor<2x1xi16>
return %1 : tensor<2x1xi16>
}

// CHECK-LABEL: @slice_int32
func.func @slice_int32() -> (tensor<2x1xi32>) {
// CHECK: "tosa.const"() <{value = dense<{{\[\[}}4], [6]]>
%0 = "tosa.const"() {value = dense<[[3, 4], [5, 6]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
%1 = "tosa.slice"(%0){size = array<i64: 2, 1>, start = array<i64: 0, 1>} : (tensor<2x2xi32>) -> tensor<2x1xi32>
return %1 : tensor<2x1xi32>
}

// CHECK-LABEL: @slice_int32_default_value
func.func @slice_int32_default_value() -> (tensor<3x1xi32>) {
// CHECK: "tosa.const"() <{value = dense<{{\[\[}}3], [6], [9]]>
%0 = "tosa.const"() {value = dense<[[3, 4, 5], [6, 7, 8], [9, 10, 11]]> : tensor<3x3xi32>} : () -> tensor<3x3xi32>
%1 = "tosa.slice"(%0){size = array<i64: 3, 1>, start = array<i64: 0, 0>} : (tensor<3x3xi32>) -> tensor<3x1xi32>
return %1 : tensor<3x1xi32>
}

// CHECK-LABEL: @slice_bf16_default_value
func.func @slice_bf16_default_value() -> (tensor<3x2xbf16>) {
// CHECK: "tosa.const"() <{value = dense<{{\[\[}}4.000000e+00, 5.000000e+00], [7.000000e+00, 8.000000e+00], [1.000000e+01, 1.100000e+01]]>
%0 = "tosa.const"() {value = dense<[[3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]> : tensor<3x3xbf16>} : () -> tensor<3x3xbf16>
%1 = "tosa.slice"(%0){size = array<i64: 3, 2>, start = array<i64: 0, 1>} : (tensor<3x3xbf16>) -> tensor<3x2xbf16>
return %1 : tensor<3x2xbf16>
}

0 comments on commit 9489ae8

Please sign in to comment.