From aaae150035f429a4ddae10fe12fc68a530bd1c89 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Wed, 12 Jun 2024 11:22:12 +0200 Subject: [PATCH] TosaToArith: Use type converter for tosa.const --- .../mlir/Conversion/TosaToArith/TosaToArith.h | 4 ++- .../Conversion/TosaToArith/TosaToArith.cpp | 25 +++++++++++++------ .../TosaToArith/TosaToArithPass.cpp | 6 ++++- .../Conversion/TosaToArith/tosa-to-arith.mlir | 14 +++++++---- 4 files changed, 35 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h b/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h index e7158ee3852e181..1d651e394b897d1 100644 --- a/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h +++ b/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h @@ -16,6 +16,7 @@ #include "mlir/Pass/Pass.h" namespace mlir { +class TypeConverter; #define GEN_PASS_DECL_TOSATOARITH #include "mlir/Conversion/Passes.h.inc" @@ -25,7 +26,8 @@ namespace tosa { std::unique_ptr createTosaToArith(bool includeApplyRescale = false, bool use32BitApplyRescale = false); -void populateTosaToArithConversionPatterns(RewritePatternSet *patterns); +void populateTosaToArithConversionPatterns(TypeConverter &converter, + RewritePatternSet *patterns); void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns, bool include32Bit = false); diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp index 50e57682a2dc8da..e1c5841a656f937 100644 --- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp +++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp @@ -16,19 +16,30 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include using namespace mlir; using namespace tosa; namespace { -class ConstOpConverter : public OpRewritePattern { +class ConstOpConverter : public OpConversionPattern { public: - using OpRewritePattern::OpRewritePattern; + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(tosa::ConstOp op, - PatternRewriter &rewriter) const final { - rewriter.replaceOpWithNewOp(op, op.getValue()); + LogicalResult matchAndRewrite(tosa::ConstOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + + auto elements = dyn_cast(adaptor.getValue()); + if (!elements) { + return rewriter.notifyMatchFailure(op, "expected dense elements attr"); + } + + auto convertedElTy = getTypeConverter()->convertType(elements.getElementType()); + if (!convertedElTy) { + return rewriter.notifyMatchFailure(op, "type conversion failed"); + } + rewriter.replaceOpWithNewOp(op, elements.bitcast(convertedElTy)); return success(); } }; @@ -238,9 +249,9 @@ class ApplyScale32BitOpConverter : public OpRewritePattern { } // namespace -void mlir::tosa::populateTosaToArithConversionPatterns( +void mlir::tosa::populateTosaToArithConversionPatterns(TypeConverter &converter, RewritePatternSet *patterns) { - patterns->add(patterns->getContext()); + patterns->add(converter, patterns->getContext()); } void mlir::tosa::populateTosaRescaleToArithConversionPatterns( diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp index de82c0335c985de..0dd24c31d94743f 100644 --- a/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp +++ b/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp @@ -19,6 +19,7 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include namespace mlir { #define GEN_PASS_DEF_TOSATOARITH @@ -34,12 +35,15 @@ struct TosaToArith : public impl::TosaToArithBase { TosaToArith(TosaToArithOptions &options) : TosaToArithBase(options) {} void runOnOperation() override { + TypeConverter converter; + mlir::tosa::populateTosaToLinalgTypeConversion(converter); + RewritePatternSet patterns(&getContext()); ConversionTarget target(getContext()); target.addIllegalOp(); target.addLegalDialect(); - mlir::tosa::populateTosaToArithConversionPatterns(&patterns); + mlir::tosa::populateTosaToArithConversionPatterns(converter, &patterns); if (this->includeApplyRescale) { mlir::tosa::populateTosaRescaleToArithConversionPatterns(&patterns, diff --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir index c4f82d53af98224..63d1423ea3ad6df 100644 --- a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir +++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir @@ -2,12 +2,16 @@ // RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=false" %s -verify-diagnostics -o -| FileCheck --check-prefix="SCALE" %s // CHECK-LABEL: func @const_test -func.func @const_test() -> (tensor) { - // CHECK: [[C3:%.+]] = arith.constant dense<3> : tensor - %result = "tosa.const"() {value = dense<3> : tensor} : () -> tensor +func.func @const_test() -> (tensor, tensor) { + // CHECK: %[[CI32:.+]] = arith.constant dense<3> : tensor + %i32 = "tosa.const"() {value = dense<3> : tensor} : () -> tensor - // CHECK: return [[C3]] - return %result : tensor + // CHECK: %[[CUI32:.+]] = arith.constant dense<3> : tensor + // CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CUI32]] : tensor to tensor + %ui32 = "tosa.const"() {value = dense<3> : tensor} : () -> tensor + + // CHECK: return %[[CI32]], %[[CAST]] + return %i32, %ui32 : tensor, tensor } // -----