Skip to content

Commit

Permalink
refine the code style (apache#10112)
Browse files Browse the repository at this point in the history
  • Loading branch information
yangulei committed Apr 25, 2022
1 parent 822d863 commit 8195136
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -862,18 +862,18 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s
Span span = Span());

// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
static const Op& op = Op::Get("tir." #OpName); \
if (x.dtype().is_bfloat16()) { \
DataType srcType = x.dtype(); \
DataType dstType(kDLFloat, 32, srcType.lanes()); \
PrimExpr castX = tir::Cast(dstType, {x}, span); \
PrimExpr result = tir::Call(dstType, op, {castX}, span); \
return tir::Cast(srcType, {result}, span); \
} else { \
return tir::Call(x.dtype(), op, {x}, span); \
} \
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
static const Op& op = Op::Get("tir." #OpName); \
if (x.dtype().is_bfloat16()) { \
DataType bf16_dtype = x.dtype(); \
DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \
PrimExpr x_fp32 = tir::Cast(fp32_dtype, {x}, span); \
PrimExpr result_fp32 = tir::Call(fp32_dtype, op, {x_fp32}, span); \
return tir::Cast(bf16_dtype, {result_fp32}, span); \
} else { \
return tir::Call(x.dtype(), op, {x}, span); \
} \
}

TVM_DECLARE_INTRIN_UNARY(exp);
Expand Down

0 comments on commit 8195136

Please sign in to comment.