From d42d8aacbbfa2ca3a48429586a2ce1f110ab6a5e Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Tue, 8 Mar 2022 08:22:23 +0800 Subject: [PATCH] refine the code style (#10112) --- dev_tvm | 1 + include/tvm/tir/op.h | 24 ++++++++++++------------ 2 files changed, 13 insertions(+), 12 deletions(-) create mode 160000 dev_tvm diff --git a/dev_tvm b/dev_tvm new file mode 160000 index 000000000000..53dca99ee0c3 --- /dev/null +++ b/dev_tvm @@ -0,0 +1 @@ +Subproject commit 53dca99ee0c3a48171613b0446a0a0f3609354a8 diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 9c3ea135c68d..cdecafa001e6 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -862,18 +862,18 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s Span span = Span()); // Intrinsic operators -#define TVM_DECLARE_INTRIN_UNARY(OpName) \ - inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ - static const Op& op = Op::Get("tir." #OpName); \ - if (x.dtype().is_bfloat16()) { \ - DataType srcType = x.dtype(); \ - DataType dstType(kDLFloat, 32, srcType.lanes()); \ - PrimExpr castX = tir::Cast(dstType, {x}, span); \ - PrimExpr result = tir::Call(dstType, op, {castX}, span); \ - return tir::Cast(srcType, {result}, span); \ - } else { \ - return tir::Call(x.dtype(), op, {x}, span); \ - } \ +#define TVM_DECLARE_INTRIN_UNARY(OpName) \ + inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ + static const Op& op = Op::Get("tir." #OpName); \ + if (x.dtype().is_bfloat16()) { \ + DataType bf16_dtype = x.dtype(); \ + DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \ + PrimExpr x_fp32 = tir::Cast(fp32_dtype, {x}, span); \ + PrimExpr result_fp32 = tir::Call(fp32_dtype, op, {x_fp32}, span); \ + return tir::Cast(bf16_dtype, {result_fp32}, span); \ + } else { \ + return tir::Call(x.dtype(), op, {x}, span); \ + } \ } TVM_DECLARE_INTRIN_UNARY(exp);