Skip to content

Commit

Permalink
[MLIR] Add idempotent trait folding
Browse files Browse the repository at this point in the history
This trait simply adds a fold of f(f(x)) = f(x) when an operation is labelled as idempotent

Reviewed By: rriddle, andyly

Differential Revision: https://reviews.llvm.org/D89421
  • Loading branch information
ahmedsabie authored and andyly committed Oct 16, 2020
1 parent 0a7cd99 commit 7dff6b8
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 1 deletion.
2 changes: 2 additions & 0 deletions mlir/include/mlir/IR/OpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -1720,6 +1720,8 @@ def ResultsBroadcastableShape :
NativeOpTrait<"ResultsBroadcastableShape">;
// X op Y == Y op X
def Commutative : NativeOpTrait<"IsCommutative">;
// op op X == op X
def Idempotent : NativeOpTrait<"IsIdempotent">;
// op op X == X
def Involution : NativeOpTrait<"IsInvolution">;
// Op behaves like a constant.
Expand Down
26 changes: 25 additions & 1 deletion mlir/include/mlir/IR/OpDefinition.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,12 @@ namespace OpTrait {
// corresponding trait classes. This avoids them being template
// instantiated/duplicated.
namespace impl {
OpFoldResult foldIdempotent(Operation *op);
OpFoldResult foldInvolution(Operation *op);
LogicalResult verifyZeroOperands(Operation *op);
LogicalResult verifyOneOperand(Operation *op);
LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
LogicalResult verifyIsIdempotent(Operation *op);
LogicalResult verifyIsInvolution(Operation *op);
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
LogicalResult verifyOperandsAreFloatLike(Operation *op);
Expand Down Expand Up @@ -1012,7 +1014,7 @@ class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
};

/// This class adds property that the operation is an involution.
/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x)
/// This means a unary to unary operation "f" that satisfies f(f(x)) = x
template <typename ConcreteType>
class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
public:
Expand All @@ -1033,6 +1035,28 @@ class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
}
};

/// This class adds property that the operation is idempotent.
/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x)
template <typename ConcreteType>
class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
public:
static LogicalResult verifyTrait(Operation *op) {
static_assert(ConcreteType::template hasTrait<OneResult>(),
"expected operation to produce one result");
static_assert(ConcreteType::template hasTrait<OneOperand>(),
"expected operation to take one operand");
static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
"expected operation to preserve type");
// Idempotent requires the operation to be side effect free as well
// but currently this check is under a FIXME and is not actually done.
return impl::verifyIsIdempotent(op);
}

static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
return impl::foldIdempotent(op);
}
};

/// This class verifies that all operands of the specified op have a float type,
/// a vector thereof, or a tensor thereof.
template <typename ConcreteType>
Expand Down
18 changes: 18 additions & 0 deletions mlir/lib/IR/Operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,16 @@ InFlightDiagnostic OpState::emitRemark(const Twine &message) {
// Op Trait implementations
//===----------------------------------------------------------------------===//

OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) {
auto *argumentOp = op->getOperand(0).getDefiningOp();
if (argumentOp && op->getName() == argumentOp->getName()) {
// Replace the outer operation output with the inner operation.
return op->getOperand(0);
}

return {};
}

OpFoldResult OpTrait::impl::foldInvolution(Operation *op) {
auto *argumentOp = op->getOperand(0).getDefiningOp();
if (argumentOp && op->getName() == argumentOp->getName()) {
Expand Down Expand Up @@ -730,6 +740,14 @@ static Type getTensorOrVectorElementType(Type type) {
return type;
}

LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) {
// FIXME: Add back check for no side effects on operation.
// Currently adding it would cause the shared library build
// to fail since there would be a dependency of IR on SideEffectInterfaces
// which is cyclical.
return success();
}

LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) {
// FIXME: Add back check for no side effects on operation.
// Currently adding it would cause the shared library build
Expand Down
7 changes: 7 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,13 @@ def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
let results = (outs I32);
}

def TestIdempotentTraitOp
: TEST_Op<"op_idempotent_trait",
[SameOperandsAndResultType, NoSideEffect, Idempotent]> {
let arguments = (ins I32:$op1);
let results = (outs I32);
}

def TestInvolutionTraitNoOperationFolderOp
: TEST_Op<"op_involution_trait_no_operation_fold",
[SameOperandsAndResultType, NoSideEffect, Involution]> {
Expand Down
34 changes: 34 additions & 0 deletions mlir/test/mlir-tblgen/trait.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,37 @@ func @testInhibitInvolution(%arg0: i32) -> i32 {
// CHECK: return [[OP]]
return %1: i32
}

//===----------------------------------------------------------------------===//
// Test that idempotent folding works correctly
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @testSingleIdempotent
// CHECK-SAME: ([[ARG0:%.+]]: i32)
func @testSingleIdempotent(%arg0 : i32) -> i32 {
// CHECK: [[IDEMPOTENT:%.+]] = "test.op_idempotent_trait"([[ARG0]])
%0 = "test.op_idempotent_trait"(%arg0) : (i32) -> i32
// CHECK: return [[IDEMPOTENT]]
return %0: i32
}

// CHECK-LABEL: func @testDoubleIdempotent
// CHECK-SAME: ([[ARG0:%.+]]: i32)
func @testDoubleIdempotent(%arg0: i32) -> i32 {
// CHECK: [[IDEMPOTENT:%.+]] = "test.op_idempotent_trait"([[ARG0]])
%0 = "test.op_idempotent_trait"(%arg0) : (i32) -> i32
%1 = "test.op_idempotent_trait"(%0) : (i32) -> i32
// CHECK: return [[IDEMPOTENT]]
return %1: i32
}

// CHECK-LABEL: func @testTripleIdempotent
// CHECK-SAME: ([[ARG0:%.+]]: i32)
func @testTripleIdempotent(%arg0: i32) -> i32 {
// CHECK: [[IDEMPOTENT:%.+]] = "test.op_idempotent_trait"([[ARG0]])
%0 = "test.op_idempotent_trait"(%arg0) : (i32) -> i32
%1 = "test.op_idempotent_trait"(%0) : (i32) -> i32
%2 = "test.op_idempotent_trait"(%1) : (i32) -> i32
// CHECK: return [[IDEMPOTENT]]
return %2: i32
}

0 comments on commit 7dff6b8

Please sign in to comment.