Skip to content

Commit

Permalink
[mlir][emitc] Arith to EmitC: handle FP<->Integer conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
cferry-AMD committed Apr 19, 2024
1 parent 4d7f3d9 commit bf83b3f
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 1 deletion.
94 changes: 93 additions & 1 deletion mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,94 @@ class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
}
};

// Floating-point to integer conversions.
template <typename CastOp>
class FtoICastOpConversion : public OpConversionPattern<CastOp> {
public:
FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
: OpConversionPattern<CastOp>(typeConverter, context) {}

LogicalResult
matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Type operandType = adaptor.getIn().getType();
if (!emitc::isSupportedFloatType(operandType))
return rewriter.notifyMatchFailure(castOp,
"unsupported cast source type");

Type dstType = this->getTypeConverter()->convertType(castOp.getType());
if (!dstType)
return rewriter.notifyMatchFailure(castOp, "type conversion failed");

if (!emitc::isSupportedIntegerType(dstType))
return rewriter.notifyMatchFailure(castOp,
"unsupported cast destination type");

// Convert to unsigned if it's the "ui" variant
// Signless is interpreted as signed, so no need to cast for "si"
Type actualResultType = dstType;
if (isa<arith::FPToUIOp>(castOp)) {
actualResultType =
rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
/*isSigned=*/false);
}

Value result = rewriter.create<emitc::CastOp>(
castOp.getLoc(), actualResultType, adaptor.getOperands());

if (isa<arith::FPToUIOp>(castOp)) {
result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result);
}
rewriter.replaceOp(castOp, result);

return success();
}
};

// Integer to floating-point conversions.
template <typename CastOp>
class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
public:
ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
: OpConversionPattern<CastOp>(typeConverter, context) {}

LogicalResult
matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Vectors in particular are not supported
Type operandType = adaptor.getIn().getType();
if (!emitc::isSupportedIntegerType(operandType))
return rewriter.notifyMatchFailure(castOp,
"unsupported cast source type");

Type dstType = this->getTypeConverter()->convertType(castOp.getType());
if (!dstType)
return rewriter.notifyMatchFailure(castOp, "type conversion failed");

if (!emitc::isSupportedFloatType(dstType))
return rewriter.notifyMatchFailure(castOp,
"unsupported cast destination type");

// Convert to unsigned if it's the "ui" variant
// Signless is interpreted as signed, so no need to cast for "si"
Type actualOperandType = operandType;
if (isa<arith::UIToFPOp>(castOp)) {
actualOperandType =
rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
/*isSigned=*/false);
}
Value fpCastOperand = adaptor.getIn();
if (actualOperandType != operandType) {
fpCastOperand = rewriter.template create<emitc::CastOp>(
castOp.getLoc(), actualOperandType, fpCastOperand);
}
rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);

return success();
}
};

} // namespace

//===----------------------------------------------------------------------===//
Expand All @@ -222,7 +310,11 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
CmpIOpConversion,
SelectOpConversion
SelectOpConversion,
ItoFCastOpConversion<arith::SIToFPOp>,
ItoFCastOpConversion<arith::UIToFPOp>,
FtoICastOpConversion<arith::FPToSIOp>,
FtoICastOpConversion<arith::FPToUIOp>
>(typeConverter, ctx);
// clang-format on
}
48 changes: 48 additions & 0 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// RUN: mlir-opt -split-input-file -convert-arith-to-emitc -verify-diagnostics %s

func.func @arith_cast_tensor(%arg0: tensor<5xf32>) -> tensor<5xi32> {
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
%t = arith.fptosi %arg0 : tensor<5xf32> to tensor<5xi32>
return %t: tensor<5xi32>
}

// -----

func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
%t = arith.fptosi %arg0 : vector<5xf32> to vector<5xi32>
return %t: vector<5xi32>
}

// -----

func.func @arith_cast_bf16(%arg0: bf16) -> i32 {
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
%t = arith.fptosi %arg0 : bf16 to i32
return %t: i32
}

// -----

func.func @arith_cast_f16(%arg0: f16) -> i32 {
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
%t = arith.fptosi %arg0 : f16 to i32
return %t: i32
}


// -----

func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 {
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
%t = arith.sitofp %arg0 : i32 to bf16
return %t: bf16
}

// -----

func.func @arith_cast_to_f16(%arg0: i32) -> f16 {
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
%t = arith.sitofp %arg0 : i32 to f16
return %t: f16
}
36 changes: 36 additions & 0 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,39 @@ func.func @arith_cmpi_predicates(%arg0: i32, %arg1: i32) {

return
}

// -----

func.func @arith_float_to_int_cast_ops(%arg0: f32, %arg1: f64) {
// CHECK: emitc.cast %arg0 : f32 to i32
%0 = arith.fptosi %arg0 : f32 to i32

// CHECK: emitc.cast %arg1 : f64 to i32
%1 = arith.fptosi %arg1 : f64 to i32

// CHECK: emitc.cast %arg0 : f32 to i16
%2 = arith.fptosi %arg0 : f32 to i16

// CHECK: emitc.cast %arg1 : f64 to i16
%3 = arith.fptosi %arg1 : f64 to i16

// CHECK: %[[CAST0:.*]] = emitc.cast %arg0 : f32 to ui32
// CHECK: emitc.cast %[[CAST0]] : ui32 to i32
%4 = arith.fptoui %arg0 : f32 to i32

return
}

func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
// CHECK: emitc.cast %arg0 : i8 to f32
%0 = arith.sitofp %arg0 : i8 to f32

// CHECK: emitc.cast %arg1 : i64 to f32
%1 = arith.sitofp %arg1 : i64 to f32

// CHECK: %[[CAST_UNS:.*]] = emitc.cast %arg0 : i8 to ui8
// CHECK: emitc.cast %[[CAST_UNS]] : ui8 to f32
%2 = arith.uitofp %arg0 : i8 to f32

return
}

0 comments on commit bf83b3f

Please sign in to comment.