
exclude embedding in MFU computation #280

Merged
merged 1 commit into gh/tianyu-l/9/base on May 2, 2024

Conversation

tianyu-l (Contributor) commented Apr 26, 2024

Stack from ghstack (oldest at bottom):

Per suggestion in #274:

This PR removes the embedding from the number-of-parameters calculation, because the embedding op doesn't perform a matmul.

This PR follows the industry convention (PaLM paper, nanoGPT, Megatron, etc.) and uses a factor of 12 in the self-attention part, even when causal attention is enabled.
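For context, here is a minimal sketch of the per-token FLOP count this convention implies; the function name and arguments are illustrative assumptions, not the PR's actual code:

```python
def flops_per_token(num_params: int, num_layers: int, num_heads: int,
                    head_dim: int, seq_len: int) -> int:
    """PaLM-style training FLOPs per token.

    num_params excludes the embedding table, since the embedding lookup
    performs no matmul. The attention term uses the conventional factor
    of 12 even when the attention is causal.
    """
    dense = 6 * num_params  # 2 FLOPs/weight forward + 4 backward
    attention = 12 * num_layers * num_heads * head_dim * seq_len
    return dense + attention
```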

tianyu-l added a commit that referenced this pull request Apr 26, 2024
ghstack-source-id: 9daa99020c76fdfe429b6a9ee6d44fd1dd319fc3
Pull Request resolved: #280
facebook-github-bot added the CLA Signed label Apr 26, 2024
tianyu-l requested review from wanchaol, Chillee and awgu Apr 26, 2024 22:49
tianyu-l mentioned this pull request Apr 26, 2024
tianyu-l requested a review from drisspg Apr 26, 2024 22:58
ad8e commented Apr 26, 2024

FlashAttention makes use of the causal mask to do half the work, so one of my friends got >100% MFU when using the 12 factor rather than 7. Common options for the multiplier in the causal setting (compared concretely in the sketch after this list) are:

6: The theoretical lower bound on the FLOP count, but not what an efficient implementation would do. I only know of one implementation that calculates it this way.
7: What FlashAttention uses and how they benchmark themselves. This is the most accurate counter to engineer against.
12: The most common choice, but not very useful in my experience, as it's easy to make the MFU go way up just by changing context length.
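A small sketch comparing the three conventions on the attention term alone; the shapes are made up and nothing here is tied to a particular kernel:

```python
def attention_flops_per_token(multiplier: int, num_layers: int,
                              num_heads: int, head_dim: int,
                              seq_len: int) -> int:
    # Per-token attention FLOPs = multiplier * L * H * D * T.
    #   12: full QK^T and PV matmuls, forward + backward, ignoring the causal mask
    #    6: half of 12, the theoretical lower bound under a causal mask
    #    7: FlashAttention's benchmarking convention, which also counts the
    #       forward recomputation performed inside the backward kernel
    return multiplier * num_layers * num_heads * head_dim * seq_len

# The same model and sequence length measured under each convention:
L, H, D, T = 32, 32, 128, 8192
for m in (6, 7, 12):
    print(m, attention_flops_per_token(m, L, H, D, T))
```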

Chillee commented Apr 27, 2024

@ad8e IMO, I would use either 6 or 12. MFU was originally intended to exclude recomputation FLOPs (from activation checkpointing), so it seems somewhat strange to me to reinclude them here. In addition, other FlashAttention implementations (like, say, Triton's) actually end up with a factor of 9 for FLOPs.

My argument for 12 would be that, if you use 6, then you need to be quite consistent about taking sparsity into account (for example, say we add sliding window attention). Otherwise, you end up back in the same situation you're referring to, where increasing sequence length results in overly inflated MFU.

Perhaps unobviously, FLOP counting is a somewhat subjective enterprise. For FlashAttention-2 itself it makes sense to benchmark with 7, since what it cares about is "how much room is there to optimize this kernel".

I think the question is whether you care more about being invariant across sequence lengths or across attention patterns. I would probably agree that sequence length is the more important factor, so you've convinced me it should be 6 :)

tianyu-l (Contributor, Author) commented May 1, 2024

@ad8e

I checked some other mainstream repos for how MFU is computed. From what I can tell, most (if not all) of them use 12.

Since ultimately MFU is a metric derived from tokens-per-second (formula), and there is no consensus on which formula to use, we feel it's safer to follow the industry convention of 12, unless the community changes it together. One can always use tokens-per-second for more direct comparisons.
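For reference, a hedged sketch of the tokens-per-second to MFU conversion being described; the helper name and the peak-FLOPS figure are assumptions for illustration:

```python
def mfu(tokens_per_second: float, flops_per_token: float,
        peak_flops_per_second: float) -> float:
    # Model FLOPs Utilization: achieved model FLOPs divided by hardware peak.
    return tokens_per_second * flops_per_token / peak_flops_per_second

# Example: ~7e9 non-embedding params gives a 6N dense term of ~4.2e10 FLOPs/token;
# at 3000 tokens/s per GPU against an assumed 312 TFLOP/s BF16 peak (A100):
print(f"{mfu(3000, 4.2e10, 312e12):.1%}")  # -> ~40.4%
```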

tianyu-l merged commit 8427d9a into gh/tianyu-l/9/base on May 2, 2024
4 checks passed
tianyu-l added a commit that referenced this pull request May 2, 2024
ghstack-source-id: 9daa99020c76fdfe429b6a9ee6d44fd1dd319fc3
Pull Request resolved: #280
tianyu-l deleted the gh/tianyu-l/9/head branch May 2, 2024 00:14
ad8e commented May 2, 2024

> I checked some other mainstream repos for how MFU is computed. From what I can tell, most (if not all) of them use 12.

If you wish to base your decision on what is most widely used, I agree that 12 is the most common number. Only FlashAttention (and I) use 7, and only one codebase I've seen uses 6.

awgu mentioned this pull request Jul 11, 2024
tianyu-l added a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
ghstack-source-id: 9daa99020c76fdfe429b6a9ee6d44fd1dd319fc3
Pull Request resolved: pytorch#280
philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request Aug 17, 2024
ghstack-source-id: 9daa99020c76fdfe429b6a9ee6d44fd1dd319fc3
Pull Request resolved: pytorch#280