// Patch summary (preamble text placed ahead of the `diff --git` header).
//
// Target: include/tvm/tir/op.h, a single hunk starting at line 862 that spans
// 18 lines on both the old and the new side.
//
// What the hunk changes: the body of the TVM_DECLARE_INTRIN_UNARY(OpName) macro
// is replaced by a version in which only the local variable names of the
// bfloat16 branch differ:
//     srcType -> bf16_dtype     (the dtype of the incoming operand)
//     dstType -> fp32_dtype     (DataType(kDLFloat, 32, <same lanes>))
//     castX   -> x_fp32         (operand cast to fp32_dtype)
//     result  -> result_fp32    (the intrinsic call evaluated in fp32_dtype)
// The statements, their order, and the non-bfloat16 branch read the same on
// both sides of the hunk.
//
// What the macro expands to (same before and after): an inline function
// `PrimExpr OpName(PrimExpr x, Span span = Span())` that looks up the op named
// "tir." #OpName once through a function-local static reference. If x has a
// bfloat16 dtype, x is cast to a 32-bit float type with the same lane count,
// the op is called on that value, and the result is cast back to x's original
// dtype. Otherwise the op is called directly with x.dtype().
//
// NOTE(review): in this copy the line breaks of the diff are collapsed onto one
// line, so the column of each trailing `\` continuation marker cannot be seen.
// Presumably lines whose text is identical on both sides differ only in that
// alignment - confirm against the real file. The line structure has to be
// restored before this text can be applied as a patch.
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 5b63016d2f9d8..905c67f1c5b0a 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -862,18 +862,18 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s Span span = Span()); // Intrinsic operators -#define TVM_DECLARE_INTRIN_UNARY(OpName) \ - inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ - static const Op& op = Op::Get("tir." #OpName); \ - if (x.dtype().is_bfloat16()) { \ - DataType srcType = x.dtype(); \ - DataType dstType(kDLFloat, 32, srcType.lanes()); \ - PrimExpr castX = tir::Cast(dstType, {x}, span); \ - PrimExpr result = tir::Call(dstType, op, {castX}, span); \ - return tir::Cast(srcType, {result}, span); \ - } else { \ - return tir::Call(x.dtype(), op, {x}, span); \ - } \ +#define TVM_DECLARE_INTRIN_UNARY(OpName) \ + inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ + static const Op& op = Op::Get("tir." #OpName); \ + if (x.dtype().is_bfloat16()) { \ + DataType bf16_dtype = x.dtype(); \ + DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \ + PrimExpr x_fp32 = tir::Cast(fp32_dtype, {x}, span); \ + PrimExpr result_fp32 = tir::Call(fp32_dtype, op, {x_fp32}, span); \ + return tir::Cast(bf16_dtype, {result_fp32}, span); \ + } else { \ + return tir::Call(x.dtype(), op, {x}, span); \ + } \ } TVM_DECLARE_INTRIN_UNARY(exp);