exclude embedding in MFU computation #280
Conversation
ghstack-source-id: 9daa99020c76fdfe429b6a9ee6d44fd1dd319fc3 Pull Request resolved: #280
FlashAttention makes use of the causal mask to do half the work, so one of my friends got >100% MFU when using the 12 factor rather than 7. Common options for multipliers in the causal setting are:

- 6: This is theoretically the lowest FLOP count possible, but it's not what an efficient implementation would use. I only know of one implementation that calculates it this way.
@ad8e imo, I would either use 6 or 12. MFU was originally intended to exclude recomputation FLOPs (from activation checkpointing), so it seems somewhat strange to me to re-include them here. In addition, other FlashAttention implementations (like, say, Triton's) actually end up with a factor of 9 for FLOPs.

My argument for 12 would be that, if you use 6, then you need to start being quite consistent about taking sparsity into account (for example, say we add sliding window attention). Otherwise, you end up back in the same situation you're referring to, where increasing the sequence length results in an overly inflated MFU. Perhaps unobviously, FLOP counting is a somewhat subjective enterprise. For FlashAttention v2 itself it makes sense to benchmark with 7, since what it cares about is "how much room is there to optimize this kernel".

I think the question is whether you care more about being "invariant" across sequence lengths or across attention patterns. I would probably agree that sequence length is the more important factor, so you've convinced me it should be 6 :)
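For readers following along, here is a minimal sketch of where the 6/7/12 multipliers come from, counting only the two sequence-length-dependent attention matmuls (Q @ K^T and P @ V). The function name and the 2.5x-backward accounting attributed to FlashAttention's benchmarks are illustrative assumptions, not code from this repo:

```python
def attn_matmul_flops_per_token(seq_len: int, model_dim: int, factor: int) -> int:
    """Per-layer, per-token FLOPs for the two attention matmuls (Q @ K^T and P @ V).

    factor = 12: full (non-causal) counting: forward 4 * s * d, backward 2x forward.
    factor = 7:  causal counting used in FlashAttention benchmarks: forward halved
                 to 2 * s * d, backward taken as 2.5x forward since it recomputes
                 the forward attention scores.
    factor = 6:  causal counting with no recomputation: forward 2 * s * d,
                 backward 4 * s * d.
    """
    return factor * seq_len * model_dim

# A kernel that skips the masked half does ~6x the work but gets credited 12x,
# which is how an MFU computed with the factor of 12 can exceed 100%.
print(attn_matmul_flops_per_token(8192, 4096, 12) // attn_matmul_flops_per_token(8192, 4096, 6))  # 2
```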
I checked some other mainstream repos for how MFU is computed. From what I can tell, most (if not all) of them use 12, e.g. the PaLM paper, nanoGPT, and Megatron.
Since MFU is ultimately a metric derived from tokens-per-second (see the sketch below) and there is no consensus on which FLOP formula to use, we feel it's safer to follow the industry convention of 12, unless the community changes it together. One can always use tokens-per-second for more direct comparisons.
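Roughly, the derivation looks like the following; the function and the example numbers are hypothetical, not taken from this repo:

```python
def mfu(tokens_per_second: float, flops_per_token: float, peak_flops_per_second: float) -> float:
    """Model FLOPs Utilization: achieved model FLOP throughput over hardware peak.

    Changing the per-token FLOP formula (e.g. attention factor 6 vs 12) rescales MFU,
    but tokens-per-second is unaffected, which keeps it directly comparable.
    """
    return tokens_per_second * flops_per_token / peak_flops_per_second

# Hypothetical 7B-parameter run: ~6 * 7e9 FLOPs/token, 9000 tokens/s per GPU,
# 989 TFLOP/s peak BF16 on an H100 -> roughly 38% MFU.
print(mfu(9_000, 6 * 7e9, 989e12))  # ~0.38
```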
If you wish to base your decision on what is most widely used, I agree that 12 is the most common number. Only FlashAttention (and I) use 7, and only one codebase I've seen uses 6.
Per suggestion in #274:
This PR removes the embedding from the number-of-parameters calculation, because the embedding op doesn't do a matmul.
This PR follows the industry convention (PaLM paper, nanoGPT, Megatron, etc.) and uses a factor of 12 for the self-attention part, even when causal attention is enabled.
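A minimal sketch of the resulting per-token FLOP estimate under that convention; the function and argument names are illustrative, not necessarily the identifiers used in this repo:

```python
def num_flops_per_token(num_params_excl_embedding: int, num_layers: int,
                        num_heads: int, head_dim: int, seq_len: int) -> int:
    """Training FLOPs per token, PaLM/nanoGPT/Megatron style.

    - 6 * params: 2 (forward) + 4 (backward) matmul FLOPs per parameter per token,
      with the embedding excluded since the embedding lookup is not a matmul.
    - 12 * layers * heads * head_dim * seq_len: the sequence-length-dependent
      attention-score matmuls, counted as full (non-causal) attention.
    """
    return (6 * num_params_excl_embedding
            + 12 * num_layers * num_heads * head_dim * seq_len)

# Hypothetical 7B-like config: 32 layers, 32 heads, head_dim 128, seq_len 4096.
print(num_flops_per_token(6_600_000_000, 32, 32, 128, 4096))  # ~4.6e10 FLOPs/token
```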