
[TIR, Relay] improve bfloat16 support #10112

Merged · 16 commits merged into apache:main on Feb 24, 2022

Conversation

yangulei (Contributor)

Motivation:

We are enabling bfloat16 in BYOC-oneDNN following the path: [float32 graph] --> <AMP> --> [bfloat16 graph] --> <BYOC> --> [TVM + oneDNN module]. However, some passes such as FoldConstant could not handle bfloat16 before the improvements below.
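
For context, a minimal sketch of the AMP step in this path, assuming a Relay module `mod` already imported from an fp32 model (the function and variable names here are illustrative, not from this PR):

import tvm
from tvm import relay

def convert_to_bf16(mod):
    # Run type inference first, then let AMP rewrite eligible ops to bfloat16.
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.ToMixedPrecision(mixed_precision_type="bfloat16")(mod)
    # FoldConstant now has to handle bfloat16 constants -- the motivation
    # for the changes in this PR.
    mod = relay.transform.FoldConstant()(mod)
    return mod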

Changes:

  • Add runtime datatype dispatch, and skip asserts on uint16, for bfloat16 compatibility.
  • Add bfloat16 casting for the unary intrinsic operators to enable graph optimization.
  • Improve the bf16_legalize module to enable bfloat16 lowering (see the sketch below).
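
As a sketch of the lowering path the last bullet refers to, assuming an IRModule `mod` of PrimFuncs (the pass name matches TVM's tir.transform at the time of this PR; treat the exact invocation as illustrative):

import tvm

# BF16Legalize rewrites bfloat16 arithmetic into float32 compute with casts
# at the boundaries, so backends without native bf16 support still work.
with tvm.transform.PassContext(opt_level=3):
    mod = tvm.tir.transform.BF16Legalize()(mod)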

With those improvements, a float32 graph can now be converted to bfloat16 through AMP and then lowered to run inference in bfloat16 mode.

Tested Models (gluoncv):

  • ResNet<18/34/50/101/152>_v1b
  • VGG<11/13/16/19>
  • VGG<11/13/16/19>_bn
  • DenseNet121
  • InceptionV3

As @AndrewZhaoLuo said at #8069: "By tested I mean I confirm it did some transformation on the graph and a forward pass could be run on CPU and matches the fp32 output somewhat. I have nothing on performance metrics or other devices yet."

Pending:

The support for bfloat16 in BYOC-oneDNN is based on the multi-blocking layout transform and extensions to BYOC-oneDNN, which are still pending.

avgpool_bf16 = relay.nn.avg_pool2d(  # call reconstructed from context; earlier lines of this excerpt are truncated
    maxpool_bf16, pool_size=(2, 2), strides=(2, 2))
mod_bf16 = tvm.IRModule.from_expr(avgpool_bf16)
with tvm.transform.PassContext(opt_level=3):
    relay.build(mod_bf16, target="llvm", params=params)
Contributor

Could we also verify the output correctness?

Contributor Author

Good point.
The correctness of a simple bfloat16 add has been checked at

tvm.testing.assert_allclose(np_bf162np_float(c_.numpy()), res)

However, verifying the correctness of bfloat16 inference is much more complex than first thought.
We usually use relative error and absolute error to evaluate accuracy, but both can be large even for a simple multiply with random inputs.
On the other hand, the MSE of the outputs (arrays with len=1000) of native bfloat16 inference for ResNet<18/34/50/101/152>_v1b and VGG<11/13/16/19> is about 0.5% to 1%, and less than 0.2% for BYOC-oneDNN inference. Also, the envelope curves of the bfloat16 and float32 outputs are close.
I couldn't find any metric good enough to estimate the accuracy of bfloat16 inference. Any suggestions?
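
For concreteness, a sketch of the comparison described above, assuming fp32 reference outputs and bfloat16 outputs already cast back to fp32 (the metric function and its names are illustrative, not from the PR's tests):

import numpy as np

def normalized_mse(bf16_out_as_fp32, fp32_ref):
    # MSE normalized by the reference signal power, so "0.5% to 1%"
    # reads as a fraction of the fp32 output magnitude.
    err = np.mean((bf16_out_as_fp32 - fp32_ref) ** 2)
    return float(err / np.mean(fp32_ref ** 2))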

Contributor

I see. Would it make sense to compute a reference result in FP32 and cast it to bfloat16 for comparison? That is the only approach I can think of, so I am not sure whether it makes sense either.

@masahi @AndrewZhaoLuo do you have comments?

Contributor Author

The errors I mentioned above already use the FP32 results as the reference, with the bfloat16 results cast to FP32 for comparison.

Member

Yeah, even fp16 is not trivial for accuracy checking. I can imagine how hard bfloat16 is too.

@Hzfengsy (Member) left a comment

Should we make bfloat16 a sub-type of float, or a separate data type? I guess it's just a special kind of float.

@@ -302,6 +302,8 @@ inline std::string DType2String(const tvm::DataType dtype) {
     os << "int";
   } else if (dtype.is_uint()) {
     os << "uint";
+  } else if (dtype.is_bfloat16()) {
+    os << "bfloat";
Member

Suggested change
-    os << "bfloat";
+    os << "bfloat16";

@@ -1177,7 +1177,8 @@ bool NLLLossRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     << ", weights shape = " << weights->shape);
     return false;
   }
-  if (!(predictions->dtype == weights->dtype && predictions->dtype.is_float())) {
+  if (!(predictions->dtype == weights->dtype &&
+        (predictions->dtype.is_float() || predictions->dtype.is_bfloat16()))) {
Member

Shall we let is_float() be true for bfloat16 exprs?

@yangulei (Contributor Author) commented on Feb 10, 2022

I prefer this way too, since they are both floating-point datatypes.
However, there are some practical inconsistencies so far: for example, if we let is_float() == true for bfloat16, then we can no longer distinguish bfloat16 from float16, as both would satisfy the condition is_float() == true && bits() == 16.
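
To illustrate the ambiguity, a sketch using TVM's Python-side DataType (the check itself is hypothetical, not code from the PR):

import tvm

fp16 = tvm.DataType("float16")
bf16 = tvm.DataType("bfloat16")

# Both types are 16 bits wide; if is_float() also matched bfloat16,
# a "float && bits == 16" test could no longer tell them apart.
assert fp16.bits == 16 and bf16.bits == 16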

@billishyahao (Contributor) commented on Feb 16, 2022

Tested Models (gluoncv):
ResNet<18/34/50/101/152>_v1b
VGG<11/13/16/19>
VGG<11/13/16/19>_bn
DenseNet121
InceptionV3

By tested I mean I confirm it did some transformation on the graph and a forward pass could be run on CPU and matches the fp32 output somewhat. I have nothing on performance metrics or other devices yet.

Hi @yangulei, thanks for the patch. Could you share some details about the transformation work you did on these graphs? Were there any op modifications?

@yangulei (Contributor Author)

Hi @billishyahao,
You can check out my fork. The readme file includes a brief introduction to setting up the environment and running the benchmarks. Be aware that this is not the final branch for upstreaming. You need a CPU with AVX-512 to support bfloat16 functionally, and a CPU with AMX to support bfloat16 natively; take a look at the oneDNN documentation for more details.
FYI: we plan to upstream a tutorial about BYOC-oneDNN after the related PRs are merged.
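
As a hedged illustration of the hardware note above, one way to request an AVX-512-capable CPU target when building (the specific -mcpu value is an example, not from this thread):

import tvm

# cascadelake is one AVX-512-capable CPU; substitute the -mcpu of your machine.
target = tvm.target.Target("llvm -mcpu=cascadelake")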

DataType dstType(kDLFloat, 32, srcType.lanes()); \
PrimExpr castX = tir::Cast(dstType, {x}, span); \
PrimExpr result = tir::Call(dstType, op, {castX}, span); \
return tir::Cast(srcType, {result}, span); \
Member

Just do tir::Cast("bfloat16", {result}, span). We use camel_case.

Can be fixed in a follow up.

Contributor Author

OK, thanks. I will refine the code style in the next PR.
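
For reference, a Python-level sketch of the cast-compute-cast pattern the macro above implements (function and variable names are illustrative, not the PR's actual code):

from tvm import tir

def bf16_unary_via_fp32(op_name, x):
    # Promote bfloat16 to float32, run the intrinsic there, demote the result,
    # e.g. bf16_unary_via_fp32("tir.exp", x) for a bfloat16 expression x.
    x32 = tir.Cast("float32", x)
    y32 = tir.call_intrin("float32", op_name, x32)
    return tir.Cast(x.dtype, y32)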

@masahi merged commit 7e2467a into apache:main on Feb 24, 2022
yangulei added a commit to yangulei/tvm that referenced this pull request Mar 8, 2022
yangulei added commits to yangulei/tvm that referenced this pull request Apr 1, 2022
pfk-beta pushed a commit to pfk-beta/tvm that referenced this pull request Apr 11, 2022
* update AMP table to enable ResNet50 conversion

* add runtime datatype dispatch for BFloat16

* skip asserts for uint16 for bf16 compatibility

* add bf16 cast for the unary intrinsic operators

* enable "bf16<-->fp32<-->any dtype" casting

* support inconsistent input for bf16 BIOP legalize

* add treatments for bfloat16 in if statements

* add bfloat16 dtype casts in binary OP

* delete unnecessary treatments for bfloat16

* add test for bfloat16 building

* code style

* restore the modifications in .gitignore

* restore the changes to AMP lists

* fix typos

* fix lint errors

* fix typo
yangulei added a commit to yangulei/tvm that referenced this pull request Apr 25, 2022
yangulei added a commit to yangulei/tvm that referenced this pull request May 19, 2022
yangulei added a commit to yangulei/tvm that referenced this pull request May 24, 2022
yangulei added a commit to yangulei/tvm that referenced this pull request May 26, 2022
masahi pushed a commit that referenced this pull request May 26, 2022
* refine the code style (#10112)

* support more data types in oneDNN BYOC

* consider dtype when query layout

* support more translation of blocked layout

* refine log for invalid layout transform

* reset N and C for the weights

* support multi-blocking in TransDims2Plain()

* add tests for bf16 oneDNN BYOC

* unregister 'round' OP in oneDNN BYOC

* restore the criteria for fp32 tests

* disable test_prune_dnnl_subgraph for bf16

* fix typo in dnnl.py

* delete tag::format_tag_last

* delete 'is_weight' in layout2tag()

* reuse dtype_dl2dnnl()

* fix lint errors

* change to WARNING for invalid layout transform

* skip bf16 tests if AVX512 is unavailable