Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] [Relay] Automatic Mixed Precision Pass #6

Merged
merged 21 commits into from
Aug 24, 2021

Conversation

AndrewZhaoLuo
Copy link
Contributor

@AndrewZhaoLuo AndrewZhaoLuo commented Jun 9, 2021

Relevant Links:

https://discuss.tvm.apache.org/t/rfc-relay-fp32-fp16-model-support/9994

  • Old discussion before the new RFC process was rolled out

apache/tvm#8069

  • Initial PR

cc @hogepodge @mbrookhart @anijain2305 @masahi

Link to tracking issue: apache/tvm#8296

@comaniac
Copy link
Contributor

comaniac commented Jun 9, 2021

Thanks for the RFC. I have two questions:

  1. How to mark/set the color (i.e., attribute) of every operator?

  2. It seems to me that if we register a casting checker instead of just a label (color), then we can simplify the algorithm a lot. Taking the case A(green) - B(gray) - C(green) as an example, if we could register a casting rule of B as follows, then we just need one traverse to know if we need cast around B:

    def amp_B(expr, args):
        a = args[0]
        if (a.dtype is float16):
          return fp16
        return fp32
    

    After all, we only need the previous nodes to determine 1) whether to use FP16 implementation, and 2) whether to insert casts. It seems to me that this pass is similar to the layout conversion pass, which uses one traverse to finish everything, so it might be possible for AMP too.

@AndrewZhaoLuo
Copy link
Contributor Author

AndrewZhaoLuo commented Jun 9, 2021

Thanks for the RFC. I have two questions:

  1. How to mark/set the color (i.e., attribute) of every operator?

  2. It seems to me that if we register a casting checker instead of just a label (color), then we can simplify the algorithm a lot. Taking the case A(green) - B(gray) - C(green) as an example, if we could register a casting rule of B as follows, then we just need one traverse to know if we need cast around B:

    def amp_B(expr, args):
        a = args[0]
        if (a.dtype is float16):
          return fp16
        return fp32
    

    After all, we only need the previous nodes to determine 1) whether to use FP16 implementation, and 2) whether to insert casts. It seems to me that this pass is similar to the layout conversion pass, which uses one traverse to finish everything, so it might be possible for AMP too.

Yep that is correct it is very similar to the layout conversion pass. This RFC has an initial PR here: apache/tvm#8069.

To answer your questions:

  1. src/relay/transforms/fp32_to_fp16.h -- DefaultFP16Colorer is the default way. But the only thing we need is a callable with type CallNode*(Color). So you could write your own colorer that does arbitrary stuff when only looking at a single node at a time.

  2. This is functionally what is done in the PR I link. It's one pass.

@comaniac
Copy link
Contributor

comaniac commented Jun 9, 2021

Thanks for the answers. I'll review the PR to get more implementation details.
One more question regarding the extensibility: can this be extended easily to support bfloat16?

@AndrewZhaoLuo
Copy link
Contributor Author

Thanks for the answers. I'll review the PR to get more implementation details.
One more question regarding the extensibility: can this be extended easily to support bfloat16?

It should be trivial (hope I don't eat my words). I'm not 100% sure of the support for bfloat16 in current relay ops however.

@AndrewZhaoLuo AndrewZhaoLuo marked this pull request as ready for review June 9, 2021 18:08
@AndrewZhaoLuo
Copy link
Contributor Author

I don't know Chris Sullivan's github handle so if someone could cc him too that would be great.

@tmoreau89
Copy link

CCing @csullivan

@comaniac
Copy link
Contributor

comaniac commented Jun 9, 2021

Thanks for the answers. I'll review the PR to get more implementation details.
One more question regarding the extensibility: can this be extended easily to support bfloat16?

It should be trivial (hope I don't eat my words). I'm not 100% sure of the support for bfloat16 in current relay ops however.

TVM has limited bfloat16 support now but it's on the way, so it would be better for this RFC to also consider this case, even the initial version may not cover it.

@AndrewZhaoLuo AndrewZhaoLuo changed the title Automatic Mixed Precision Pass RFC [RFC] [Relay] Automatic Mixed Precision Pass Jun 9, 2021
@AndrewZhaoLuo
Copy link
Contributor Author

AndrewZhaoLuo commented Jun 15, 2021

So the associated PR is getting closer to a mergeable state. Is this RFC ready for more comments?

@tqchen
Copy link
Member

tqchen commented Jul 24, 2021

cc @comaniac would be great if you can help shepherd this RFC

@tqchen tqchen added the status: need review RFC needs review label Jul 24, 2021
Copy link
Contributor

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The concept and algorithm look good to me, but it would be better to provide more implementation/design details.

Comment on lines 102 to 103
We can support automatic mixed precision retraining though that is a much, much larger future goal. It's
good to have this in the meantime.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The answer to this question should come with a discussion of existing mechanisms used by other frameworks, such as XLA and PyTorch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Please let me know if this is sufficient. Don't have the best background on some of this stuff.

rfcs/0001-AMP_pass.md Outdated Show resolved Hide resolved
rfcs/0001-AMP_pass.md Outdated Show resolved Hide resolved
rfcs/0001-AMP_pass.md Outdated Show resolved Hide resolved
rfcs/0001-AMP_pass.md Outdated Show resolved Hide resolved
rfcs/0001-AMP_pass.md Outdated Show resolved Hide resolved
rfcs/0001-AMP_pass.md Outdated Show resolved Hide resolved
rfcs/0001-AMP_pass.md Outdated Show resolved Hide resolved
@comaniac comaniac added the status: need update RFC needs update based on feedback label Jul 27, 2021
@AndrewZhaoLuo
Copy link
Contributor Author

Thanks for driving this review @comaniac. I'll get to this later in the week.

@AndrewZhaoLuo
Copy link
Contributor Author

Going to get to this tomorrow 😬. Promise 🤞

@comaniac
Copy link
Contributor

comaniac commented Aug 4, 2021

btw, according to #17, please update the RFC number on the file name to align with this PR number.

@comaniac
Copy link
Contributor

Took a quick pass to the updated RFC. I think it's almost ready to merge as long as the last 3 comments are resolved.

@AndrewZhaoLuo
Copy link
Contributor Author

PTAL @comaniac

@AndrewZhaoLuo
Copy link
Contributor Author

@comaniac, I'll be talking about this at the TVM community meeting tomorrow so put off merging until after.

Copy link
Contributor

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Will merge after the community meeting if there's no objection.

@comaniac comaniac added status: accepted RFC is accepted and removed status: need review RFC needs review status: need update RFC needs update based on feedback labels Aug 18, 2021
@AndrewZhaoLuo
Copy link
Contributor Author

If there is not other objections, this will be merged on monday.

@comaniac comaniac merged commit dd2e7a8 into apache:main Aug 24, 2021
@comaniac
Copy link
Contributor

Thanks @AndrewZhaoLuo

@MeJerry215
Copy link

@AndrewZhaoLuo will it remove cast weight to float16 from graph? and make weight as float16 when build lib.
in my opinion, it will reduce the bandwidth.

@masahi
Copy link
Member

masahi commented Dec 30, 2021

@MeJerry215 Yes, casting of weight to fp16 is done at compile time by FoldConstant pass, so weights will be in fp16 at deploy time.

MichaelJKlaiber added a commit to MichaelJKlaiber/tvm-rfcs that referenced this pull request Apr 6, 2022
uma-rfc: update to questions/comments added
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
status: accepted RFC is accepted
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants