Skip to content

Commit

Permalink
[mlir][tosa] Fix assertion failure in tosa-layerwise-constant-fold (llvm#85670)
Browse files Browse the repository at this point in the history

The existing implementation of tosa-layerwise-constant-fold only works
for constant values backed by DenseElementsAttr. For constants which
hold DenseResourceAttrs, the folder will end up asserting at runtime, as
it assumes that the backing data can always be accessed through
ElementsAttr::getValues.

This change reworks the logic so that the types used to perform
folding are based on whether the ElementsAttr can be converted to a
range of that particular type.

---------

Co-authored-by: Spenser Bauman <sabauma@mathworks.com>
Co-authored-by: Tina Jung <tinamaria.jung@amd.com>
  • Loading branch information
3 people authored Mar 21, 2024
1 parent 7340263 commit fa6e433
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 34 deletions.
73 changes: 39 additions & 34 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,17 @@ bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
return inputOp.hasOneUse();
}

template <typename BaseType>
DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
template <typename RangeType>
DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
ShapedType outputType,
llvm::ArrayRef<int64_t> permValues) {
using ElementType = std::decay_t<decltype(*std::begin(data))>;

assert(inputType.getElementType() == outputType.getElementType());

if (inputType.getNumElements() == 0)
return DenseElementsAttr::get(outputType, llvm::ArrayRef<BaseType>{});
return DenseElementsAttr::get(outputType, llvm::ArrayRef<ElementType>{});

auto attrValues = attr.getValues<BaseType>();
auto inputShape = inputType.getShape();

// The inverted permutation map and strides of the output are used to compute
Expand All @@ -148,10 +151,11 @@ DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
auto outputStrides = computeStrides(outputType.getShape());
auto invertedPermValues = invertPermutationVector(permValues);

auto initialValue = *std::begin(attrValues);
SmallVector<BaseType> outputValues(inputType.getNumElements(), initialValue);
auto initialValue = *std::begin(data);
SmallVector<ElementType> outputValues(inputType.getNumElements(),
initialValue);

for (const auto &it : llvm::enumerate(attrValues)) {
for (const auto &it : llvm::enumerate(data)) {
auto srcLinearIndex = it.index();

uint64_t dstLinearIndex = 0;
Expand All @@ -170,7 +174,7 @@ DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
}

return DenseElementsAttr::get(outputType,
llvm::ArrayRef<BaseType>(outputValues));
llvm::ArrayRef<ElementType>(outputValues));
}

// A type specialized transposition of an ElementsAttr.
Expand All @@ -180,32 +184,28 @@ DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
// Dispatch the transposition over the element types that the attribute can
// actually produce a value range for. ElementsAttr::tryGetValues returns
// failure (instead of asserting, as getValues does) when the backing storage
// cannot be iterated as the requested type — e.g. for constants backed by
// DenseResourceElementsAttr — so unsupported inputs fall through to the
// nullptr return and the caller can decline to fold.
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
                            ShapedType outputType,
                            llvm::ArrayRef<int64_t> permValues) {
  if (auto data = attr.tryGetValues<bool>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int8_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int16_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int32_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int64_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<float>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<APFloat>())
    return transposeType(*data, inputType, outputType, permValues);

  // No supported element type matched; signal failure so the rewrite pattern
  // reports a match failure rather than crashing.
  return nullptr;
}

struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
Expand All @@ -228,14 +228,19 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
DenseIntElementsAttr permAttr;
if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
return failure();
auto permValues = llvm::to_vector<6>(llvm::map_range(
auto permValues = llvm::map_to_vector(
// TOSA allows both 32- and 64-bit integer tensors here.
permAttr.getValues<APInt>(),
[](const APInt &val) { return val.getSExtValue(); }));
[](const APInt &val) { return val.getSExtValue(); });

auto inputType = cast<ShapedType>(op.getInput1().getType());

auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
if (!resultAttr) {
return rewriter.notifyMatchFailure(
op, "unsupported attribute or element type");
}

rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
return success();
}
Expand Down
17 changes: 17 additions & 0 deletions mlir/test/Dialect/Tosa/constant-op-fold.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,23 @@ func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i
return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
}

// Regression test: a transpose whose input constant is backed by a
// dense_resource must NOT be folded (the folder cannot read the resource data
// as typed values), and the pass must not assert on it.
// CHECK-LABEL: @transpose_nofold_dense_resource
func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> {
  %0 = "tosa.const"() <{value = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
  %1 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>

  // The transpose op must survive constant folding.
  // CHECK: tosa.transpose
  %2 = tosa.transpose %0, %1 : (tensor<2x2xf32>, tensor<2xi32>) -> tensor<2x2xf32>
  return %2 : tensor<2x2xf32>
}
{-#
  dialect_resources: {
    builtin: {
      resource: "0x08000000010000000000000002000000000000000300000000000000"
    }
  }
#-}

// -----

// CHECK-LABEL: @fold_add_zero_rhs_f32
Expand Down

0 comments on commit fa6e433

Please sign in to comment.