
[TIR] Refactor BF16Legalize #14405

Merged 1 commit into apache:main on Mar 28, 2023
Conversation

tqchen (Member) commented on Mar 26, 2023

This PR refactors BF16Legalize so that more of the computation runs in f32. BF16Legalize is now split into two steps:

  • BF16ComputeLegalize rewrites all computation to f32 while keeping
    the external BF16 storage.
  • BF16StorageLegalize rewrites all storage to u16.

BF16 kernels now accept tvm.nd.array inputs created with the bfloat16 dtype.
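The two-step split described above comes down to simple bit-level arithmetic: a bf16 value is the upper 16 bits of an IEEE-754 f32, so widening (compute legalization) is a left shift and narrowing (storage legalization) is a rounded right shift. A minimal stdlib-only Python sketch of that arithmetic, assuming round-to-nearest-even narrowing; the helper names are illustrative, not TVM APIs:

```python
import struct


def f32_to_bits(x):
    """Reinterpret a Python float as its IEEE-754 float32 bit pattern."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(b):
    """Reinterpret a 32-bit pattern as an IEEE-754 float32 value."""
    return struct.unpack("<f", struct.pack("<I", b))[0]


def f32_to_bf16(x):
    """Narrow f32 to a bf16 word stored as u16 (storage legalization).

    Uses round-to-nearest-even on the discarded lower 16 bits.
    NaN/overflow handling is omitted for brevity.
    """
    bits = f32_to_bits(x)
    lsb = (bits >> 16) & 1          # keep-bit parity for ties-to-even
    bits += 0x7FFF + lsb            # round the discarded half up on ties
    return (bits >> 16) & 0xFFFF


def bf16_to_f32(word):
    """Widen a bf16 word back to f32 (compute legalization)."""
    return bits_to_f32(word << 16)
```

For example, 1.0 (f32 bits 0x3F800000) narrows to the bf16 word 0x3F80 and widens back to exactly 1.0, since its mantissa fits in bf16's 7 bits.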

tvm-bot (Collaborator) commented on Mar 26, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

tqchen (Member, Author) commented on Mar 26, 2023

cc @yangulei @Menooker

tqchen force-pushed the bf16-refactor branch 5 times, most recently from 5d9440a to d2011c4 on March 27, 2023 at 00:29.
tqchen (Member, Author) commented on Mar 27, 2023

Note: the Android CI failure is not related to this PR.

tqchen (Member, Author) commented on Mar 27, 2023

cc @vinx13

junrushao (Member): Retriggering failed Android tests.

Comment on lines +831 to +834
// TODO(tvm-team): consider add native support
ICHECK(!from.is_bfloat16()) << "BF16 needs to be storaged lowered first";
ICHECK(!to.is_bfloat16()) << "BF16 needs to be storaged lowered first";

Reviewer (Member):

It seems possible to support software BF16 in the upstream LLVM/MLIR world; leaving it to future work.

tqchen (Member, Author):

Good point; it is already left as a TODO.

Comment on lines +102 to +108
# handle the bfloat16 so we explicitly allocate
# bfloat16 arrays as input
for i, param in enumerate(mod["main"].params):
    if param.type_annotation.dtype == "bfloat16":
        input_data[i] = tvm.nd.empty(input_data[i].shape, "bfloat16").copyfrom(
            input_data[i]
        )
Reviewer (Member):

Why are we adding ONNX tests in this PR, though?

tqchen (Member, Author):

This is needed to patch up the ONNX frontend, which converts bf16 to uint16.
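For context on why the frontend deals in uint16: the ONNX format stores bfloat16 tensor payloads as raw 16-bit words (each word being the upper half of an f32 bit pattern), so decoding them is a bit-level reinterpretation rather than a numeric cast. A minimal stdlib-only sketch of that decoding, assuming that storage convention; the helper name is illustrative, not a TVM or ONNX API:

```python
import struct


def decode_bf16_words(words):
    """Reinterpret raw uint16 bfloat16 words as Python floats.

    Each bf16 word is the upper half of an IEEE-754 float32 bit
    pattern, so widening is a 16-bit left shift, not a value cast.
    """
    return [struct.unpack("<f", struct.pack("<I", w << 16))[0] for w in words]
```

For instance, the words 0x3F80 and 0x4000 decode to 1.0 and 2.0 respectively.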

@junrushao junrushao merged commit a0edf24 into apache:main Mar 28, 2023