Skip to content

Commit

Permalink
Merge pull request #124 from Xilinx/tina.FXML-4224-add-arith-to-emitc
Browse files Browse the repository at this point in the history
[MLIR][EmitC] Add arith to emitc conversion
  • Loading branch information
mgehre-amd authored Mar 1, 2024
2 parents 0e7d4e9 + c7d2689 commit 2c8cec1
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 0 deletions.
22 changes: 22 additions & 0 deletions mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===- ArithToEmitC.h - Convert Arith to EmitC ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H

#include "mlir/Pass/Pass.h"

namespace mlir {
class RewritePatternSet;

#define GEN_PASS_DECL_ARITHTOEMITCCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"

void populateArithToEmitCConversionPatterns(RewritePatternSet &patterns);
} // namespace mlir

#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,18 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
}

//===----------------------------------------------------------------------===//
// ArithToEmitC
//===----------------------------------------------------------------------===//

def ArithToEmitCConversionPass : Pass<"convert-arith-to-emitc"> {
let summary = "Convert Arith ops to EmitC ops";
let description = [{
Convert `arith` operations to operations in the `emitc` dialect.
}];
let dependentDialects = ["emitc::EmitCDialect"];
}

//===----------------------------------------------------------------------===//
// ArithToLLVM
//===----------------------------------------------------------------------===//
Expand Down
104 changes: 104 additions & 0 deletions mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
//===- ArithToEmitC.cpp - Arith to EmitC conversion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert arith ops into emitc ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
#define GEN_PASS_DEF_ARITHTOEMITCCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

static bool isConvertibleToEmitC(Type type) {
Type baseType = type;
if (auto tensorType = dyn_cast<TensorType>(type)) {
if (!tensorType.hasRank() || !tensorType.hasStaticShape()) {
return false;
}
baseType = tensorType.getElementType();
}

if (isa<IndexType>(baseType)) {
return true;
}

if (auto intType = dyn_cast<IntegerType>(baseType)) {
switch (intType.getWidth()) {
case 1:
case 8:
case 16:
case 32:
case 64:
return true;
}
return false;
}

if (auto floatType = dyn_cast<FloatType>(baseType)) {
return floatType.isF32() || floatType.isF64();
}

return false;
}

class ArithConstantOpConversionPattern
: public OpRewritePattern<arith::ConstantOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(arith::ConstantOp arithConst,
PatternRewriter &rewriter) const override {

auto constantType = arithConst.getType();
if (!isConvertibleToEmitC(constantType)) {
return rewriter.notifyMatchFailure(arithConst.getLoc(),
"Type cannot be converted to emitc");
}

rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, constantType,
arithConst.getValue());
return success();
}
};

struct ConvertArithToEmitCPass
: public impl::ArithToEmitCConversionPassBase<ConvertArithToEmitCPass> {
public:
void runOnOperation() override {

ConversionTarget target(getContext());
target.addIllegalDialect<arith::ArithDialect>();
target.addLegalDialect<emitc::EmitCDialect>();
RewritePatternSet patterns(&getContext());
populateArithToEmitCConversionPatterns(patterns);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
signalPassFailure();
}
}
};

} // namespace

void mlir::populateArithToEmitCConversionPatterns(RewritePatternSet &patterns) {
patterns.add<ArithConstantOpConversionPattern>(patterns.getContext());
}
17 changes: 17 additions & 0 deletions mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
add_mlir_conversion_library(ArithToEmitC
ArithToEmitC.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToEmitC

DEPENDS
MLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIREmitCDialect
MLIRArithDialect
MLIRTransforms
)
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
add_subdirectory(ArithToEmitC)
add_subdirectory(ArithToLLVM)
add_subdirectory(ArithToSPIRV)
add_subdirectory(ArmNeon2dToIntr)
Expand Down
15 changes: 15 additions & 0 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emit-c-failed.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: mlir-opt -split-input-file -convert-arith-to-emitc -verify-diagnostics %s

func.func @arith_constant_complex_tensor() -> (tensor<complex<i32>>) {
// expected-error @+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
%c = arith.constant dense<(2, 2)> : tensor<complex<i32>>
return %c : tensor<complex<i32>>
}

// -----

func.func @arith_constant_invalid_int_type() -> (i10) {
// expected-error @+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
%c = arith.constant 0 : i10
return %c : i10
}
21 changes: 21 additions & 0 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: mlir-opt -split-input-file -convert-arith-to-emitc %s | FileCheck %s

// CHECK-LABEL: arith_constants
func.func @arith_constants() {
// CHECK: emitc.constant
// CHECK-SAME: value = 0 : index
%c_index = arith.constant 0 : index
// CHECK: emitc.constant
// CHECK-SAME: value = 0 : i32
%c_signless_int_32 = arith.constant 0 : i32
// CHECK: emitc.constant
// CHECK-SAME: value = 0.{{0+}}e+00 : f32
%c_float_32 = arith.constant 0.0 : f32
// CHECK: emitc.constant
// CHECK-SAME: value = dense<0> : tensor<i32>
%c_tensor_single_value = arith.constant dense<0> : tensor<i32>
// CHECK: emitc.constant
// CHECK-SAME: value{{.*}}[1, 2], [-3, 9], [0, 0], [2, -1]{{.*}}tensor<4x2xi64>
%c_tensor_value = arith.constant dense<[[1, 2], [-3, 9], [0, 0], [2, -1]]> : tensor<4x2xi64>
return
}

0 comments on commit 2c8cec1

Please sign in to comment.