Merge with fixes of fa6e433
cferry-AMD committed Aug 16, 2024
2 parents a074d69 + fa6e433 commit 030442b
Showing 2 changed files with 103 additions and 72 deletions.
158 changes: 86 additions & 72 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -166,6 +166,17 @@ DenseElementsAttr applyElementWise(
  return newTensor;
}

/// Function that checks if the type contained in \p toCheck is float.
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
                               PatternRewriter &rewriter) {
  if (isa<FloatType>(toCheck.getType().getElementType())) {
    return success();
  }
  return rewriter.notifyMatchFailure(location,
                                     "Unexpected input tensor type: the "
                                     "TOSA spec only allows floats");
}

/// Function that checks if \p toCheck is a dense TOSA constant tensor.
LogicalResult notifyIfNoTosaDenseConstantTensor(Value toCheck,
                                                TosaOp location,
@@ -191,10 +202,50 @@ LogicalResult notifyIfNoTosaDenseConstantTensor(Value toCheck,
                                     "it operates on a TOSA constant");
}

template <typename BaseType, typename RangeT>
void transposeArray(RangeT inputValues, ShapedType inputType,
                    SmallVector<BaseType> &outputValues, ShapedType outputType,
                    llvm::ArrayRef<int64_t> permValues) {
/// Function that checks if \p toCheck is a dense TOSA constant float tensor.
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
                                                 TosaOp location,
                                                 PatternRewriter &rewriter) {
  auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
  if (failed(floatCheck)) {
    return floatCheck;
  }
  return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
}

/// Heuristic to decide when to replace a unary operation on a constant with
/// the folded value.
/// Folding operations on constants can increase memory usage whenever the
/// input cannot be replaced and a new constant is inserted instead. Hence,
/// this currently only suggests folding when the memory impact is negligible.
/// Takes the \p unaryOp and the constant input \p values.
/// \returns Whether folding should be applied.
bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
  assert(unaryOp->getNumOperands() == 1);
  auto inputOp = unaryOp->getOperand(0);

  // If the input is a splat, the number of users does not matter.
  if (isa<SplatElementsAttr>(values)) {
    return true;
  }

  // If this is the only use of the tensor, it should be replaced, as no
  // additional memory is required.
  return inputOp.hasOneUse();
}
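
// Editor's sketch (hypothetical IR, not part of this commit) of how the
// heuristic applies to a unary op such as tosa.reciprocal:
//
//   %c0 = "tosa.const" ... dense<4.0> : tensor<1024xf32>    // splat input:
//   %r0 = tosa.reciprocal %c0          // always folded, result is one splat
//
//   %c1 = "tosa.const" ... dense<[...]> : tensor<1024xf32>  // non-splat:
//   %r1 = tosa.reciprocal %c1          // folded only if %c1 has this single
//                                      // use, else both tensors stay live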

template <typename RangeType>
DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
                                ShapedType outputType,
                                llvm::ArrayRef<int64_t> permValues) {
  using ElementType = std::decay_t<decltype(*std::begin(data))>;

  assert(inputType.getElementType() == outputType.getElementType());

  if (inputType.getNumElements() == 0)
    return DenseElementsAttr::get(outputType, llvm::ArrayRef<ElementType>{});

  auto inputShape = inputType.getShape();

  // The inverted permutation map and strides of the output are used to compute
@@ -203,7 +254,11 @@ void transposeArray(RangeT inputValues, ShapedType inputType,
  auto outputStrides = computeStrides(outputType.getShape());
  auto invertedPermValues = invertPermutationVector(permValues);

  for (auto it : llvm::enumerate(inputValues)) {
  auto initialValue = *std::begin(data);
  SmallVector<ElementType> outputValues(inputType.getNumElements(),
                                        initialValue);

  for (const auto &it : llvm::enumerate(data)) {
    auto srcLinearIndex = it.index();
    uint64_t dstLinearIndex = 0;

@@ -220,86 +275,40 @@ void transposeArray(RangeT inputValues, ShapedType inputType,

    outputValues[dstLinearIndex] = it.value();
  }
}

template <typename BaseType>
DenseElementsAttr transposeTypeRaw(DenseElementsAttr attr, ShapedType inputType,
                                   ShapedType outputType,
                                   llvm::ArrayRef<int64_t> permValues) {
  ArrayRef<BaseType> inputValues =
      cast<DenseIntOrFPElementsAttr>(attr).getNonSplatRawData<BaseType>();

  SmallVector<BaseType> outputValues;
  outputValues.resize_for_overwrite(inputType.getNumElements());
  transposeArray<BaseType>(inputValues, inputType, /*out*/ outputValues,
                           outputType, permValues);

  ArrayRef rawOutputValues(reinterpret_cast<const char *>(outputValues.data()),
                           outputValues.size() * sizeof(BaseType));
  return DenseElementsAttr::getFromRawBuffer(outputType, rawOutputValues);
}

template <typename BaseType>
DenseElementsAttr transposeType(DenseElementsAttr attr, ShapedType inputType,
                                ShapedType outputType,
                                llvm::ArrayRef<int64_t> permValues) {

  auto inputValues = attr.getValues<BaseType>();
  SmallVector<BaseType> outputValues(inputType.getNumElements(),
                                     *std::begin(inputValues));
  transposeArray<BaseType>(inputValues, inputType, /*out*/ outputValues,
                           outputType, permValues);
  return DenseElementsAttr::get(outputType,
                                llvm::ArrayRef<BaseType>(outputValues));
                                llvm::ArrayRef<ElementType>(outputValues));
}
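
// Editor's illustration (not part of this commit) of the index arithmetic
// above: transposing a 2x3 input with permValues = [1, 0] produces a 3x2
// output. Source linear index 1 delinearizes to (0, 1) in the input; the
// inverted permutation sends input dim i to output dim invertedPermValues[i],
// so the element lands at destination (1, 0), which is linear index
// 1 * 2 + 0 = 2 under the output strides [2, 1].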

// A type specialized transposition of an ElementsAttr.
// This implementation tries to operate on the underlying data in its raw
// representation when possible to avoid allocating a large number of Attribute
// objects.
DenseElementsAttr transpose(DenseElementsAttr attr, ShapedType inputType,
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
                            ShapedType outputType,
                            llvm::ArrayRef<int64_t> permValues) {
  if (auto data = attr.tryGetValues<bool>())
    return transposeType(*data, inputType, outputType, permValues);

  assert(outputType.getNumElements() == inputType.getNumElements());
  assert(outputType.getElementType() == inputType.getElementType());
  if (auto data = attr.tryGetValues<int8_t>())
    return transposeType(*data, inputType, outputType, permValues);

  auto baseType = inputType.getElementType();
  if (auto data = attr.tryGetValues<int16_t>())
    return transposeType(*data, inputType, outputType, permValues);

  // Handle possible integer types
  if (auto intType = dyn_cast<IntegerType>(baseType)) {
    switch (intType.getWidth()) {
    case 1:
      // i1 has special alignment which is not handled by transposeTypeRaw.
      return transposeType<bool>(attr, inputType, outputType, permValues);
    case 8:
      return transposeTypeRaw<uint8_t>(attr, inputType, outputType, permValues);
    case 16:
      return transposeTypeRaw<uint16_t>(attr, inputType, outputType,
                                        permValues);
    case 32:
      return transposeTypeRaw<uint32_t>(attr, inputType, outputType,
                                        permValues);
    case 64:
      return transposeTypeRaw<uint64_t>(attr, inputType, outputType,
                                        permValues);
    default:
      return transposeType<APInt>(attr, inputType, outputType, permValues);
    }
  }
  if (auto data = attr.tryGetValues<int32_t>())
    return transposeType(*data, inputType, outputType, permValues);

  // Handle possible float types
  if (baseType.isF32()) {
    return transposeTypeRaw<uint32_t>(attr, inputType, outputType, permValues);
  }
  if (baseType.isF64()) {
    return transposeTypeRaw<uint64_t>(attr, inputType, outputType, permValues);
  }
  if (baseType.isBF16()) {
    return transposeTypeRaw<uint16_t>(attr, inputType, outputType, permValues);
  }
  if (auto data = attr.tryGetValues<int64_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<float>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<APFloat>())
    return transposeType(*data, inputType, outputType, permValues);

  return transposeType<APFloat>(attr, inputType, outputType, permValues);
  return nullptr;
}
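
// Editor's note (assumption, not commit text): attributes that cannot be
// iterated as one of the listed element types (e.g. resource-backed
// constants whose data is not materialized) fall through to the nullptr
// return above; the caller turns that into a pattern match failure.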

template<typename TosaOp>
@@ -553,14 +562,19 @@ struct TosaFoldConstantTranspose : public TosaFoldConstantBase<tosa::TransposeOp
    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return failure();
    auto permValues = llvm::to_vector<6>(llvm::map_range(
    auto permValues = llvm::map_to_vector(
        // TOSA allows both 32- and 64-bit integer tensors here.
        permAttr.getValues<APInt>(),
        [](const APInt &val) { return val.getSExtValue(); }));
        [](const APInt &val) { return val.getSExtValue(); });

    auto inputType = cast<ShapedType>(op.getInput1().getType());

    auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
    if (!resultAttr) {
      return rewriter.notifyMatchFailure(
          op, "unsupported attribute or element type");
    }

    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
    return success();
  }
17 changes: 17 additions & 0 deletions mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -124,6 +124,23 @@ func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i
  return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
}

// CHECK-LABEL: @transpose_nofold_dense_resource
func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> {
  %0 = "tosa.const"() <{value = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
  %1 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>

  // CHECK: tosa.transpose
  %2 = tosa.transpose %0, %1 : (tensor<2x2xf32>, tensor<2xi32>) -> tensor<2x2xf32>
  return %2 : tensor<2x2xf32>
}
{-#
  dialect_resources: {
    builtin: {
      resource: "0x08000000010000000000000002000000000000000300000000000000"
    }
  }
#-}
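
// Editor's note (assumption, not commit text): the resource-backed constant
// above cannot be read through ElementsAttr::tryGetValues, so transpose()
// returns nullptr and the pattern reports "unsupported attribute or element
// type" instead of folding, which is why the CHECK expects tosa.transpose
// to survive.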

// -----

// CHECK-LABEL: @fold_add_zero_rhs_f32