Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
## Summary The analogous `logits.float()` calls were moved in the Hugging Face modeling source code to be inside the `if labels is not None` block to avoid upcasting logits unless they are being used in a loss calculation; this avoids a memory spike during inference if the model is in lower precision. * https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/llama/modeling_llama.py#L1211-L1212 * https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/mixtral/modeling_mixtral.py#L1329-L1330 * https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/phi3/modeling_phi3.py#L1303-L1304 * https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/qwen2/modeling_qwen2.py#L1206-L1207 Some of your models already have this change: https://github.com/linkedin/Liger-Kernel/blob/ff6650bbcef5d31b7522694cbeb73a21169460e9/src/liger_kernel/transformers/model/mistral.py#L114-L116 https://github.com/linkedin/Liger-Kernel/blob/ff6650bbcef5d31b7522694cbeb73a21169460e9/src/liger_kernel/transformers/model/gemma.py#L114-L116 See also: * huggingface/transformers#30860 <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence
- Loading branch information