Skip to content

Commit

Permalink
Review
Browse files Browse the repository at this point in the history
  • Loading branch information
josel-amd authored and cferry-AMD committed Apr 12, 2024
1 parent 62450d3 commit 7a17cb7
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 25 deletions.
55 changes: 50 additions & 5 deletions mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APInt.h"

using namespace mlir;

Expand Down Expand Up @@ -79,6 +80,13 @@ Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc,
firstIsNaN, secondIsNaN);
}

emitc::ConstantOp getConstant(OpBuilder rewriter, Location loc,
llvm::APInt val) {
auto type = rewriter.getIntegerType(val.getBitWidth());
return rewriter.create<emitc::ConstantOp>(loc, type,
rewriter.getIntegerAttr(type, val));
}

class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
public:
using OpConversionPattern::OpConversionPattern;
Expand All @@ -96,6 +104,39 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
bool unordered = false;
emitc::CmpPredicate predicate;
switch (op.getPredicate()) {
case arith::CmpFPredicate::AlwaysFalse: {
auto constant = getConstant(rewriter, op->getLoc(), llvm::APInt(1, 0));
rewriter.replaceOp(op, constant);
return success();
}
case arith::CmpFPredicate::OEQ:
unordered = false;
predicate = emitc::CmpPredicate::eq;
break;
case arith::CmpFPredicate::OGT:
// ordered and greater than
unordered = false;
predicate = emitc::CmpPredicate::gt;
break;
case arith::CmpFPredicate::OGE:
unordered = false;
predicate = emitc::CmpPredicate::ge;
break;
case arith::CmpFPredicate::OLT:
unordered = false;
predicate = emitc::CmpPredicate::lt;
break;
case arith::CmpFPredicate::ONE:
unordered = false;
predicate = emitc::CmpPredicate::ne;
break;
case arith::CmpFPredicate::ORD: {
// ordered, i.e. none of the operands is NaN
auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(),
adaptor.getRhs());
rewriter.replaceOp(op, cmp);
return success();
}
case arith::CmpFPredicate::UEQ:
// unordered or equal
unordered = true;
Expand All @@ -121,18 +162,22 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
unordered = true;
predicate = emitc::CmpPredicate::le;
break;
case arith::CmpFPredicate::UNE:
unordered = true;
predicate = emitc::CmpPredicate::ne;
break;
case arith::CmpFPredicate::UNO: {
// unordered, i.e. either operand is nan
auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(),
adaptor.getRhs());
rewriter.replaceOp(op, cmp);
return success();
}
case arith::CmpFPredicate::OGT:
// ordered and greater than
unordered = false;
predicate = emitc::CmpPredicate::gt;
break;
case arith::CmpFPredicate::AlwaysTrue: {
auto constant = getConstant(rewriter, op->getLoc(), llvm::APInt(1, 1));
rewriter.replaceOp(op, constant);
return success();
}
default:
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot match predicate ");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,6 @@ func.func @arith_cmpf_vector(%arg0: vector<5xf32>, %arg1: vector<5xf32>) -> vect

// -----

func.func @arith_cmpf_ordered(%arg0: f32, %arg1: f32) -> i1 {
// expected-error @+1 {{failed to legalize operation 'arith.cmpf'}}
%oge = arith.cmpf oge, %arg0, %arg1 : f32
return %oge: i1
}

// -----

func.func @arith_cast_f32(%arg0: f32) -> i32 {
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
%t = arith.fptosi %arg0 : f32 to i32
Expand Down
134 changes: 122 additions & 12 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,105 @@ func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -

// -----

func.func @arith_cmpf_false(%arg0: f32, %arg1: f32) -> i1 {
// CHECK-LABEL: arith_cmpf_false
// CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
// CHECK-DAG: [[False:[^ ]*]] = "emitc.constant"() <{value = false}> : () -> i1
%ueq = arith.cmpf false, %arg0, %arg1 : f32
// CHECK: return [[False]]
return %ueq: i1
}

// -----

func.func @arith_cmpf_oeq(%arg0: f32, %arg1: f32) -> i1 {
// CHECK-LABEL: arith_cmpf_oeq
// CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
// CHECK-DAG: [[EQ:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg1]] : (f32, f32) -> i1
// CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
// CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
// CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1
// CHECK-DAG: [[OEQ:[^ ]*]] = emitc.logical_and [[Ordered]], [[EQ]] : i1, i1
%ueq = arith.cmpf oeq, %arg0, %arg1 : f32
// CHECK: return [[OEQ]]
return %ueq: i1
}

// -----

func.func @arith_cmpf_ogt(%arg0: f32, %arg1: f32) -> i1 {
// CHECK-LABEL: arith_cmpf_ogt
// CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
// CHECK-DAG: [[GT:[^ ]*]] = emitc.cmp gt, [[Arg0]], [[Arg1]] : (f32, f32) -> i1
// CHECK-DAG: [[OrderedArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
// CHECK-DAG: [[OrderedArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
// CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[OrderedArg0]], [[OrderedArg1]] : i1, i1
// CHECK-DAG: [[OGT:[^ ]*]] = emitc.logical_and [[Ordered]], [[GT]] : i1, i1
%ogt = arith.cmpf ogt, %arg0, %arg1 : f32
// CHECK: return [[OGT]]
return %ogt: i1
}

// -----

func.func @arith_cmpf_oge(%arg0: f32, %arg1: f32) -> i1 {
// CHECK-LABEL: arith_cmpf_oge
// CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
// CHECK-DAG: [[GE:[^ ]*]] = emitc.cmp ge, [[Arg0]], [[Arg1]] : (f32, f32) -> i1
// CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
// CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
// CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1
// CHECK-DAG: [[OGE:[^ ]*]] = emitc.logical_and [[Ordered]], [[GE]] : i1, i1
%ueq = arith.cmpf oge, %arg0, %arg1 : f32
// CHECK: return [[OGE]]
return %ueq: i1
}

// -----

func.func @arith_cmpf_olt(%arg0: f32, %arg1: f32) -> i1 {
// CHECK-LABEL: arith_cmpf_olt
// CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
// CHECK-DAG: [[LT:[^ ]*]] = emitc.cmp lt, [[Arg0]], [[Arg1]] : (f32, f32) -> i1
// CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
// CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
// CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1
// CHECK-DAG: [[UEQ:[^ ]*]] = emitc.logical_and [[Ordered]], [[LT]] : i1, i1
%ueq = arith.cmpf olt, %arg0, %arg1 : f32
// CHECK: return [[UEQ]]
return %ueq: i1
}

// -----

func.func @arith_cmpf_one(%arg0: f32, %arg1: f32) -> i1 {
// CHECK-LABEL: arith_cmpf_one
// CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
// CHECK-DAG: [[NEQ:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg1]] : (f32, f32) -> i1
// CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
// CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
// CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1
// CHECK-DAG: [[ONE:[^ ]*]] = emitc.logical_and [[Ordered]], [[NEQ]] : i1, i1
%ueq = arith.cmpf one, %arg0, %arg1 : f32
// CHECK: return [[ONE]]
return %ueq: i1
}

// -----

func.func @arith_cmpf_ord(%arg0: f32, %arg1: f32) -> i1 {
// CHECK-LABEL: arith_cmpf_ord
// CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
// CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
// CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
// CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1
%ueq = arith.cmpf ord, %arg0, %arg1 : f32
// CHECK: return [[Ordered]]
return %ueq: i1
}

// -----

func.func @arith_cmpf_ueq(%arg0: f32, %arg1: f32) -> i1 {
// CHECK-LABEL: arith_cmpf_ueq
// CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
Expand Down Expand Up @@ -133,30 +232,41 @@ func.func @arith_cmpf_ule(%arg0: f32, %arg1: f32) -> i1 {

// -----

func.func @arith_cmpf_une(%arg0: f32, %arg1: f32) -> i1 {
// CHECK-LABEL: arith_cmpf_une
// CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
// CHECK-DAG: [[NEQ:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg1]] : (f32, f32) -> i1
// CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
// CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp ne, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
// CHECK-DAG: [[Unordered:[^ ]*]] = emitc.logical_or [[NaNArg0]], [[NaNArg1]] : i1, i1
// CHECK-DAG: [[UNE:[^ ]*]] = emitc.logical_or [[Unordered]], [[NEQ]] : i1, i1
%une = arith.cmpf une, %arg0, %arg1 : f32
// CHECK: return [[UNE]]
return %une: i1
}

// -----

func.func @arith_cmpf_uno(%arg0: f32, %arg1: f32) -> i1 {
// CHECK-LABEL: arith_cmpf_uno
// CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
// CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
// CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp ne, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
// CHECK-DAG: [[Unordered:[^ ]*]] = emitc.logical_or [[NaNArg0]], [[NaNArg1]] : i1, i1
%2 = arith.cmpf uno, %arg0, %arg1 : f32
%uno = arith.cmpf uno, %arg0, %arg1 : f32
// CHECK: return [[Unordered]]
return %2: i1
return %uno: i1
}

// -----

func.func @arith_cmpf_ogt(%arg0: f32, %arg1: f32) -> i1 {
// CHECK-LABEL: arith_cmpf_ogt
func.func @arith_cmpf_true(%arg0: f32, %arg1: f32) -> i1 {
// CHECK-LABEL: arith_cmpf_true
// CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32)
// CHECK-DAG: [[GT:[^ ]*]] = emitc.cmp gt, [[Arg0]], [[Arg1]] : (f32, f32) -> i1
// CHECK-DAG: [[OrderedArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1
// CHECK-DAG: [[OrderedArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1
// CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[OrderedArg0]], [[OrderedArg1]] : i1, i1
// CHECK-DAG: [[OGT:[^ ]*]] = emitc.logical_and [[Ordered]], [[GT]] : i1, i1
%ule = arith.cmpf ogt, %arg0, %arg1 : f32
// CHECK: return [[OGT]]
return %ule: i1
// CHECK-DAG: [[True:[^ ]*]] = "emitc.constant"() <{value = true}> : () -> i1
%ueq = arith.cmpf true, %arg0, %arg1 : f32
// CHECK: return [[True]]
return %ueq: i1
}

// -----
Expand Down

0 comments on commit 7a17cb7

Please sign in to comment.