diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp index 050f8ca3f32aeda..6575b39fd45a1f8 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp @@ -132,14 +132,17 @@ bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) { return inputOp.hasOneUse(); } -template -DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType, +template +DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType, ShapedType outputType, llvm::ArrayRef permValues) { + using ElementType = std::decay_t; + + assert(inputType.getElementType() == outputType.getElementType()); + if (inputType.getNumElements() == 0) - return DenseElementsAttr::get(outputType, llvm::ArrayRef{}); + return DenseElementsAttr::get(outputType, llvm::ArrayRef{}); - auto attrValues = attr.getValues(); auto inputShape = inputType.getShape(); // The inverted permutation map and strides of the output are used to compute @@ -148,10 +151,11 @@ DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType, auto outputStrides = computeStrides(outputType.getShape()); auto invertedPermValues = invertPermutationVector(permValues); - auto initialValue = *std::begin(attrValues); - SmallVector outputValues(inputType.getNumElements(), initialValue); + auto initialValue = *std::begin(data); + SmallVector outputValues(inputType.getNumElements(), + initialValue); - for (const auto &it : llvm::enumerate(attrValues)) { + for (const auto &it : llvm::enumerate(data)) { auto srcLinearIndex = it.index(); uint64_t dstLinearIndex = 0; @@ -170,7 +174,7 @@ DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType, } return DenseElementsAttr::get(outputType, - llvm::ArrayRef(outputValues)); + llvm::ArrayRef(outputValues)); } // A type specialized transposition of an ElementsAttr. @@ -180,32 +184,28 @@ DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType, DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType, ShapedType outputType, llvm::ArrayRef permValues) { - auto baseType = inputType.getElementType(); - - // Handle possible integer types - if (auto intType = dyn_cast(baseType)) { - switch (intType.getWidth()) { - case 1: - return transposeType(attr, inputType, outputType, permValues); - case 8: - return transposeType(attr, inputType, outputType, permValues); - case 16: - return transposeType(attr, inputType, outputType, permValues); - case 32: - return transposeType(attr, inputType, outputType, permValues); - case 64: - return transposeType(attr, inputType, outputType, permValues); - default: - return transposeType(attr, inputType, outputType, permValues); - } - } + if (auto data = attr.tryGetValues()) + return transposeType(*data, inputType, outputType, permValues); - // Handle possible float types - if (baseType.isF32()) { - return transposeType(attr, inputType, outputType, permValues); - } + if (auto data = attr.tryGetValues()) + return transposeType(*data, inputType, outputType, permValues); + + if (auto data = attr.tryGetValues()) + return transposeType(*data, inputType, outputType, permValues); + + if (auto data = attr.tryGetValues()) + return transposeType(*data, inputType, outputType, permValues); - return transposeType(attr, inputType, outputType, permValues); + if (auto data = attr.tryGetValues()) + return transposeType(*data, inputType, outputType, permValues); + + if (auto data = attr.tryGetValues()) + return transposeType(*data, inputType, outputType, permValues); + + if (auto data = attr.tryGetValues()) + return transposeType(*data, inputType, outputType, permValues); + + return nullptr; } struct TosaFoldConstantTranspose : public OpRewritePattern { @@ -228,14 +228,19 @@ struct TosaFoldConstantTranspose : public OpRewritePattern { DenseIntElementsAttr permAttr; if (!matchPattern(op.getPerms(), m_Constant(&permAttr))) return failure(); - auto permValues = llvm::to_vector<6>(llvm::map_range( + auto permValues = llvm::map_to_vector( // TOSA allows both 32- and 64-bit integer tensors here. permAttr.getValues(), - [](const APInt &val) { return val.getSExtValue(); })); + [](const APInt &val) { return val.getSExtValue(); }); auto inputType = cast(op.getInput1().getType()); auto resultAttr = transpose(inputValues, inputType, outputType, permValues); + if (!resultAttr) { + return rewriter.notifyMatchFailure( + op, "unsupported attribute or element type"); + } + rewriter.replaceOpWithNewOp(op, outputType, resultAttr); return success(); } diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir index 27ca3ae3c21be68..de752f31fcbaa1e 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -112,6 +112,23 @@ func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01}>> } +// CHECK-LABEL: @transpose_nofold_dense_resource +func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> { + %0 = "tosa.const"() <{value = dense_resource : tensor<2x2xf32>}> : () -> tensor<2x2xf32> + %1 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32> + + // CHECK: tosa.transpose + %2 = tosa.transpose %0, %1 : (tensor<2x2xf32>, tensor<2xi32>) -> tensor<2x2xf32> + return %2 : tensor<2x2xf32> +} +{-# + dialect_resources: { + builtin: { + resource: "0x08000000010000000000000002000000000000000300000000000000" + } + } +#-} + // ----- // CHECK-LABEL: @fold_add_zero_rhs_f32