diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index daa2ac0e3085ecc..c5a4c16604a308f 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -18,6 +18,7 @@ #include "mlir/Pass/Pass.h" namespace mlir { +class TypeConverter; namespace tosa { #define GEN_PASS_DECL @@ -37,6 +38,8 @@ void populateTosaConstantReduction(MLIRContext *ctx, RewritePatternSet &patterns, bool aggressiveReduceConstant); +void populateTosaTypeConversion(TypeConverter &converter); + std::unique_ptr createTosaLayerwiseConstantFoldPass(); std::unique_ptr createTosaLayerwiseConstantFoldPass( const TosaLayerwiseConstantFoldPassOptions &options); diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp index 95eca8d61115311..c0c015ab34aab0a 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -224,7 +224,8 @@ class ReshapeConverter : public OpConversionPattern { matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto loc = reshape.getLoc(); - auto resultType = cast_if_present(getTypeConverter()->convertType(reshape.getType())); + auto resultType = cast_if_present( + getTypeConverter()->convertType(reshape.getType())); if (!resultType) { return rewriter.notifyMatchFailure(reshape.getLoc(), "could not convert result type"); @@ -296,12 +297,13 @@ class SliceConverter : public OpConversionPattern { } }; -class PadConverter : public OpRewritePattern { +class PadConverter : public OpConversionPattern { public: - using OpRewritePattern::OpRewritePattern; + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(tosa::PadOp padOp, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { auto loc = padOp.getLoc(); auto input = padOp.getInput1(); auto padding = padOp.getPadding(); @@ -437,10 +439,7 @@ struct ConcatConverter : public OpConversionPattern { void mlir::tosa::populateTosaToTensorConversionPatterns( TypeConverter &converter, RewritePatternSet *patterns) { - patterns->add< - ConcatConverter, - PadConverter, - SliceConverter - >(patterns->getContext()); - patterns->add(converter, patterns->getContext()); + patterns + ->add( + converter, patterns->getContext()); } diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp index 31049777323be8b..23a45e718e8871b 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp @@ -45,7 +45,7 @@ struct TosaToTensor : public impl::TosaToTensorBase { target.addLegalDialect(); TypeConverter converter; - mlir::tosa::populateTosaToLinalgTypeConversion(converter); + mlir::tosa::populateTosaTypeConversion(converter); mlir::tosa::populateTosaToTensorConversionPatterns(converter, &patterns); diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index 0e6510ba1e92554..c78a74b874aff1c 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaLayerwiseConstantFoldPass.cpp TosaMakeBroadcastable.cpp TosaOptionalDecompositions.cpp + TosaTypeConverters.cpp TosaValidation.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp new file mode 100644 index 000000000000000..d2650de8cd7f020 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp @@ -0,0 +1,52 @@ + +//===- TosaTypeConverters.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Type converters for lowering TOSA to linalg/arith. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/Transforms/Passes.h" + +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +void mlir::tosa::populateTosaTypeConversion(TypeConverter &converter) { + converter.addConversion([&](Type type) -> std::optional { + if (type.isUnsignedInteger()) { + return IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth(), + IntegerType::SignednessSemantics::Signless); + } + return type; + }); + converter.addConversion([&](TensorType type) -> std::optional { + auto converted = converter.convertType(type.getElementType()); + if (!converted) + return {}; + return type.clone(converted); + }); + converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> std::optional { + if (inputs.size() != 1) + return std::nullopt; + + return builder.create(loc, resultType, inputs) + .getResult(0); + }); + converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> std::optional { + if (inputs.size() != 1) + return std::nullopt; + + return builder.create(loc, resultType, inputs) + .getResult(0); + }); +} diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir index 46d21b8f3bbe7b4..9c0e8108a422d21 100644 --- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir +++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir @@ -445,6 +445,20 @@ func.func @test_reshape_6d_down_s2s_explicit(%arg0: tensor<1x2x3x5x7x11xf32>) -> // ----- +// CHECK-LABEL: @test_reshape_samerank_unsigned +// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xui8>) +func.func @test_reshape_samerank_unsigned(%arg0: tensor<3x2xui8>) -> tensor<2x3xui8> { + // CHECK-NEXT: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<3x2xui8> to tensor<3x2xi8> + // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[CAST1]] {{\[}}[0, 1]] : tensor<3x2xi8> into tensor<6xi8> + // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] output_shape {{\[}}2, 3] : tensor<6xi8> into tensor<2x3xi8> + // CHECK-NEXT: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE2]] : tensor<2x3xi8> to tensor<2x3xui8 + %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<3x2xui8>) -> tensor<2x3xui8> + // CHECK-NEXT: return %[[CAST2]] + return %0 : tensor<2x3xui8> +} + +// ----- + // CHECK-LABEL: func @slice func.func @slice(%arg0: tensor<6xf32>) ->() { // CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]