[TIR] Refactor BF16Legalize #14405
Conversation
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot

Force-pushed from 5d9440a to d2011c4
This PR refactors BF16Legalize to enable more f32 computations. We also split BF16Legalize into two steps:

- BF16ComputeLegalize changes all computation to f32 while keeping the external BF16 storage.
- BF16StorageLegalize changes all storage to u16.

Now BF16 kernels accept tvm.nd.array inputs that are created with the bfloat16 type.
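For context, the two legalization steps rest on the fact that bfloat16 is the upper 16 bits of an IEEE float32. Below is a minimal numpy sketch (not TVM's actual implementation; the function names are illustrative) of what "storage as u16" and "compute as f32" mean at the bit level:

```python
import numpy as np

def bf16_to_f32(u16: np.ndarray) -> np.ndarray:
    """Widen uint16 bf16 storage to float32 by shifting into the high bits."""
    return (u16.astype(np.uint32) << 16).view(np.float32)

def f32_to_bf16(f32: np.ndarray) -> np.ndarray:
    """Truncate float32 to bf16 storage with round-to-nearest-even."""
    bits = f32.view(np.uint32)
    rounding_bias = ((bits >> 16) & 1) + 0x7FFF
    return ((bits + rounding_bias) >> 16).astype(np.uint16)

# Values exactly representable in bf16 round-trip losslessly.
x = np.float32([1.0, -2.5, 3.0])
assert np.array_equal(bf16_to_f32(f32_to_bf16(x)), x)
```

A legalized kernel would load the u16 storage, widen with something like `bf16_to_f32`, do all arithmetic in f32, then narrow the result back before storing.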
Note: the Android failure is not related to this PR.
cc @vinx13

Retriggering failed Android tests
// TODO(tvm-team): consider adding native support
ICHECK(!from.is_bfloat16()) << "BF16 needs to be storage lowered first";
ICHECK(!to.is_bfloat16()) << "BF16 needs to be storage lowered first";
It seems possible to support software BF16 in the upstream LLVM/MLIR world; leaving it to future work.
Good point; already left as a TODO.
# Handle bfloat16: we explicitly allocate
# bfloat16 arrays as input
for i, param in enumerate(mod["main"].params):
    if param.type_annotation.dtype == "bfloat16":
        input_data[i] = tvm.nd.empty(input_data[i].shape, "bfloat16").copyfrom(
            input_data[i]
        )
Why are we adding ONNX tests in this PR, though?
This is needed to patch up ONNX converting bf16 to uint16.