Skip to content

Commit

Permalink
TosaToArith: Use type converter for tosa.const (#195)
Browse files Browse the repository at this point in the history
* TosaToArith: Use type converter for tosa.const

Co-authored-by: Tina Jung <tinamaria.jung@amd.com>
  • Loading branch information
mgehre-amd and TinaAMD authored Jun 12, 2024
1 parent 9db9dd7 commit f77637f
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 14 deletions.
4 changes: 3 additions & 1 deletion mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Pass/Pass.h"

namespace mlir {
class TypeConverter;

#define GEN_PASS_DECL_TOSATOARITH
#include "mlir/Conversion/Passes.h.inc"
Expand All @@ -25,7 +26,8 @@ namespace tosa {
std::unique_ptr<Pass> createTosaToArith(bool includeApplyRescale = false,
bool use32BitApplyRescale = false);

void populateTosaToArithConversionPatterns(RewritePatternSet *patterns);
void populateTosaToArithConversionPatterns(TypeConverter &converter,
RewritePatternSet *patterns);

void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns,
bool include32Bit = false);
Expand Down
25 changes: 18 additions & 7 deletions mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,31 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace tosa;

namespace {

class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
class ConstOpConverter : public OpConversionPattern<tosa::ConstOp> {
public:
using OpRewritePattern<tosa::ConstOp>::OpRewritePattern;
using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(tosa::ConstOp op,
PatternRewriter &rewriter) const final {
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValue());
LogicalResult matchAndRewrite(tosa::ConstOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {

auto elements = dyn_cast<DenseElementsAttr>(adaptor.getValue());
if (!elements) {
return rewriter.notifyMatchFailure(op, "expected dense elements attr");
}

auto convertedElTy = getTypeConverter()->convertType(elements.getElementType());
if (!convertedElTy) {
return rewriter.notifyMatchFailure(op, "type conversion failed");
}
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, elements.bitcast(convertedElTy));
return success();
}
};
Expand Down Expand Up @@ -238,9 +249,9 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {

} // namespace

void mlir::tosa::populateTosaToArithConversionPatterns(
void mlir::tosa::populateTosaToArithConversionPatterns(TypeConverter &converter,
RewritePatternSet *patterns) {
patterns->add<ConstOpConverter>(patterns->getContext());
patterns->add<ConstOpConverter>(converter, patterns->getContext());
}

void mlir::tosa::populateTosaRescaleToArithConversionPatterns(
Expand Down
6 changes: 5 additions & 1 deletion mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"

namespace mlir {
#define GEN_PASS_DEF_TOSATOARITH
Expand All @@ -34,12 +35,15 @@ struct TosaToArith : public impl::TosaToArithBase<TosaToArith> {
TosaToArith(TosaToArithOptions &options) : TosaToArithBase(options) {}

void runOnOperation() override {
TypeConverter converter;
mlir::tosa::populateTosaToLinalgTypeConversion(converter);

RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
target.addIllegalOp<tosa::ConstOp>();
target.addLegalDialect<arith::ArithDialect>();

mlir::tosa::populateTosaToArithConversionPatterns(&patterns);
mlir::tosa::populateTosaToArithConversionPatterns(converter, &patterns);

if (this->includeApplyRescale) {
mlir::tosa::populateTosaRescaleToArithConversionPatterns(&patterns,
Expand Down
14 changes: 9 additions & 5 deletions mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=false" %s -verify-diagnostics -o -| FileCheck --check-prefix="SCALE" %s

// CHECK-LABEL: func @const_test
func.func @const_test() -> (tensor<i32>) {
// CHECK: [[C3:%.+]] = arith.constant dense<3> : tensor<i32>
%result = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
func.func @const_test() -> (tensor<i32>, tensor<ui32>) {
// CHECK: %[[CI32:.+]] = arith.constant dense<3> : tensor<i32>
%i32 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>

// CHECK: return [[C3]]
return %result : tensor<i32>
// CHECK: %[[CUI32:.+]] = arith.constant dense<3> : tensor<i32>
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CUI32]] : tensor<i32> to tensor<ui32>
%ui32 = "tosa.const"() {value = dense<3> : tensor<ui32>} : () -> tensor<ui32>

// CHECK: return %[[CI32]], %[[CAST]]
return %i32, %ui32 : tensor<i32>, tensor<ui32>
}

// -----
Expand Down

0 comments on commit f77637f

Please sign in to comment.