Skip to content

Commit

Permalink
[FXML-4440] Introduce a pass that annotates the type of the argument …
Browse files Browse the repository at this point in the history
…as an attribute (#167)

Create a pass that annotates all func op inputs with their type as attributes.
  • Loading branch information
josel-amd authored Apr 22, 2024
1 parent 29e1ec6 commit f799862
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 0 deletions.
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/Func/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,15 @@ def DuplicateFunctionEliminationPass : Pass<"duplicate-function-elimination",
let constructor = "mlir::func::createDuplicateFunctionEliminationPass()";
}

def AnnotateFunctionType: Pass<"annotate-function-type", "func::FuncOp"> {
let summary = "Annotate the function type as type attributes";
let description = [{
Annotates all the inputs and outputs of func.func operators with a type
attribute. The type attribute mirrors the actual type of the inputs/outputs.

This pass can be used to trace back the original types of func.func
operators in case they need to be modified.
}];
}

#endif // MLIR_DIALECT_FUNC_TRANSFORMS_PASSES_TD
47 changes: 47 additions & 0 deletions mlir/lib/Dialect/Func/Transforms/AnnotateFunctionType.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//===- AnnotateInputTypes.cpp - Type attribute annotation for func ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that creates type attributes for func parameters,
// that mirror the actual type. This is useful when the func op input types
// might change.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace mlir::func {
#define GEN_PASS_DEF_ANNOTATEFUNCTIONTYPE
#include "mlir/Dialect/Func/Transforms/Passes.h.inc"
} // namespace mlir::func

namespace {
struct AnnotateFunctionTypePass
: public mlir::func::impl::AnnotateFunctionTypeBase<
AnnotateFunctionTypePass> {

void runOnOperation() override {
func::FuncOp func = getOperation();
auto inputs = func.getArgumentTypes();
auto results = func.getResultTypes();

for (const auto [argNum, type] : llvm::enumerate(inputs)) {
func.setArgAttr(argNum, "func.orig_type", TypeAttr::get(type));
}

for (const auto [resultNum, type] : llvm::enumerate(results)) {
func.setResultAttr(resultNum, "func.orig_type", TypeAttr::get(type));
}
}
};
} // namespace
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRFuncTransforms
AnnotateFunctionType.cpp
DecomposeCallGraphTypes.cpp
DuplicateFunctionElimination.cpp
FuncBufferize.cpp
Expand Down
27 changes: 27 additions & 0 deletions mlir/test/Dialect/Func/annotate-types.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// RUN: mlir-opt %s --split-input-file --annotate-function-type | FileCheck %s

// CHECK-LABEL: func.func @one_arg(%arg0: tensor<f32> {func.orig_type = tensor<f32>}) -> (tensor<f32> {func.orig_type = tensor<f32>}) {
func.func @one_arg(%arg0: tensor<f32>) -> tensor<f32> {
return %arg0 : tensor<f32>
}

// -----

// CHECK-LABEL: func.func @one_arg_int(%arg0: tensor<ui8> {func.orig_type = tensor<ui8>}) -> (tensor<ui8> {func.orig_type = tensor<ui8>}) {
func.func @one_arg_int(%arg0: tensor<ui8>) -> tensor<ui8> {
return %arg0 : tensor<ui8>
}

// -----

// CHECK-LABEL: func.func @n_rank_tensor(%arg0: tensor<3x4x5xf32> {func.orig_type = tensor<3x4x5xf32>}) -> (tensor<3x4x5xf32> {func.orig_type = tensor<3x4x5xf32>}) {
func.func @n_rank_tensor(%arg0: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
return %arg0 : tensor<3x4x5xf32>
}

// -----

// CHECK-LABEL: func.func @two_args(%arg0: f32 {func.orig_type = f32}, %arg1: f32 {func.orig_type = f32}) -> (f32 {func.orig_type = f32}) {
func.func @two_args(%arg0: f32, %arg1: f32) -> f32 {
return %arg0 : f32
}

0 comments on commit f799862

Please sign in to comment.