[TIR, Relay] improve bfloat16 support #10112
Conversation
maxpool_bf16, pool_size=(2, 2), strides=(2, 2))
mod_bf16 = tvm.IRModule.from_expr(avgpool_bf16)
with tvm.transform.PassContext(opt_level=3):
    relay.build(mod_bf16, target="llvm", params=params)
Could we also verify the output correctness?
Good point.
The correctness of a simple bfloat16 add has been checked at
tvm.testing.assert_allclose(np_bf162np_float(c_.numpy()), res)
However, verifying the correctness of bfloat16 inference is much more complex than first thought.
We usually use relative error and absolute error to evaluate the accuracy, but both of them could be large even for a simple multiply with random inputs. On the other hand, the MSE of the outputs (arrays with len=1000) of native bfloat16 inference for ResNet<18/34/50/101/152>_v1b and VGG<11/13/16/19> is about 0.5% to 1%, and less than 0.2% for BYOC-oneDNN inference. The envelope curves of the bfloat16 and float32 outputs are also close.
I couldn't find any metrics good enough to estimate the accuracy of bfloat16 inference. Any suggestions for this?
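(For context, a minimal numpy sketch of the conversion and an MSE-style comparison discussed above; np_bf162np_float is a plausible implementation of the helper referenced in the test, and rel_mse is a hypothetical metric shown only for illustration.)

import numpy as np

def np_bf162np_float(arr):
    # Reinterpret uint16 bfloat16 bits as float32 by shifting them into
    # the upper half of a 32-bit word.
    return np.left_shift(arr.astype("uint32"), 16).view("<f4")

def rel_mse(bf16_out, fp32_ref):
    # Hypothetical metric: mean squared error of the bfloat16 output
    # (converted to float32) relative to the float32 reference.
    out = np_bf162np_float(bf16_out)
    return np.mean((out - fp32_ref) ** 2) / np.mean(fp32_ref ** 2)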
I see. Does it make sense if we calculate a reference result using FP32 and cast it to bfloat16 for comparison? This is the only way I could think of, so I'm not sure whether it makes sense.
@masahi @AndrewZhaoLuo do you have comments?
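(As a minimal sketch of that idea: one common way to cast float32 to bfloat16 in numpy is to keep the upper 16 bits with round-to-nearest-even; the helper name below is illustrative.)

import numpy as np

def np_float2np_bf16(arr):
    # Round a float32 array to bfloat16 (stored as uint16) by keeping the
    # upper 16 bits with round-to-nearest-even.
    bits = arr.view("<u4")
    bias = np.bitwise_and(np.right_shift(bits, 16), 1) + 0x7FFF
    return np.right_shift(bits + bias, 16).astype("<u2")

A float32 reference could then be rounded through such a cast and compared against the bfloat16 output, or both could be converted back to float32 for a tolerance-based check.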
The errors I mentioned above use FP32 results as the reference, and the bfloat16 results are cast to FP32 for comparison.
Yeah, even fp16 is not trivial for accuracy checking. I can imagine how hard bfloat is too.
Should we make bfloat be a sub-type of float, or a separate data type? I guess it's just a special kind of float.
@@ -302,6 +302,8 @@ inline std::string DType2String(const tvm::DataType dtype) {
     os << "int";
   } else if (dtype.is_uint()) {
     os << "uint";
+  } else if (dtype.is_bfloat16()) {
+    os << "bfloat";
os << "bfloat"; | |
os << "bfloat16"; |
@@ -1177,7 +1177,8 @@ bool NLLLossRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                                      << ", weights shape = " << weights->shape);
     return false;
   }
-  if (!(predictions->dtype == weights->dtype && predictions->dtype.is_float())) {
+  if (!(predictions->dtype == weights->dtype &&
+        (predictions->dtype.is_float() || predictions->dtype.is_bfloat16()))) {
Shall we let is_float() be true for bfloat16 exprs?
I prefer this way too, since they are all floating-point datatypes.
However, there are some practical inconsistencies so far. For example, if we let is_float() == true for bfloat16, then we cannot distinguish bfloat16 and float16 anymore, as they both satisfy the condition is_float() == true && bits() == 16.
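(A small numpy illustration of why the two 16-bit types behave differently, even though both would pass a bits() == 16 check.)

import numpy as np

x = np.array([1.0e20], dtype="float32")
print(x.astype("float16"))            # [inf]: float16 has only 5 exponent bits (max finite ~65504)
bf16_bits = x.view("<u4") >> 16       # bfloat16 keeps float32's 8 exponent bits
print((bf16_bits << 16).view("<f4"))  # finite, roughly 1e20 with reduced mantissa precision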
Tested Models (gluoncv):
Hi @yangulei, thanks for the patch. Could you share some details about the transformation work you have done on these graphs? Could it be some op modifications?
Hi @billishyahao,
DataType dstType(kDLFloat, 32, srcType.lanes());          \
PrimExpr castX = tir::Cast(dstType, {x}, span);           \
PrimExpr result = tir::Call(dstType, op, {castX}, span);  \
return tir::Cast(srcType, {result}, span);                \
Just do tir::Cast("bfloat16", {result}, span). We use camel_case.
Can be fixed in a follow up.
OK, thanks. I will refine the code style in the next PR.
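(For reference, a rough Python-level sketch of the legalization pattern shown in the macro above: cast the bfloat16 operand to float32, evaluate the intrinsic, then cast the result back. The exp intrinsic is just a placeholder example.)

from tvm import tir

x = tir.Var("x", "bfloat16")
x_f32 = tir.Cast("float32", x)        # bfloat16 -> float32
y_f32 = tir.exp(x_f32)                # evaluate the unary intrinsic in float32
y_bf16 = tir.Cast("bfloat16", y_f32)  # cast the result back to bfloat16
print(y_bf16)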
* update AMP table to enable ResNet50 conversion
* add runtime datatype dispatch for BFloat16
* skip asserts for uint16 for bf16 compatibility
* add bf16 cast for the unary intrinsic operators
* enable "bf16<-->fp32<-->any dtype" casting
* support inconsistent input for bf16 BIOP legalize
* add treatments for bfloat16 in if statements
* add bfloat16 dtype casts in binary OP
* delete unnecessary treatments for bfloat16
* add test for bfloat16 building
* code style
* restore the modifications in .gitignore
* restore the changes to AMP lists
* fix typos
* fix lint errors
* fix typo
* refine the code style (#10112)
* support more data types in oneDNN BYOC
* consider dtype when query layout
* support more translation of blocked layout
* refine log for invalid layout transform
* reset N and C for the weights
* support multi-blocking in TransDims2Plain()
* add tests for bf16 oneDNN BYOC
* unregister 'round' OP in oneDNN BYOC
* restore the criteria for fp32 tests
* disable test_prune_dnnl_subgraph for bf16
* fix typo in dnnl.py
* delete tag::format_tag_last
* delete 'is_weight' in layout2tag()
* reuse dtype_dl2dnnl()
* fix lint errors
* change to WARNING for invalid layout transform
* skip bf16 tests if AVX512 is unavailable
Motivation:
We are enabling bfloat16 in BYOC-oneDNN following the path: [float32 graph] --> <AMP> --> [bfloat16 graph] --> <BYOC> --> [TVM + oneDNN module]. However, some passes like FoldConstant could not work with bfloat16 before the improvements below.
Changes:
With those improvements, a float32 graph can now be converted to bfloat16 through AMP and then lowered for inference in bfloat16 mode.
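(A rough sketch of that flow, assuming mod and params come from a frontend importer; ToMixedPrecision is the AMP pass in tvm.relay.transform.)

import tvm
from tvm import relay

def build_bf16(mod, params, target="llvm"):
    # AMP: rewrite eligible ops from float32 to bfloat16, then build the module.
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.ToMixedPrecision("bfloat16")(mod)
    with tvm.transform.PassContext(opt_level=3):
        return relay.build(mod, target=target, params=params)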
Tested Models (gluoncv):
As @AndrewZhaoLuo said at #8069
Pending:
The support for bfloat16 in BYOC-oneDNN is based on multi-blocking layout transforms and extensions to BYOC-oneDNN, and is still pending.