Fix huggingface GA issue for llama (#333)
## Summary

Fixes #322

This PR introduces a new `lce_forward` compatible with `transformers>=4.46.0` (after the grad acc fix) while ensuring backward compatibility.

To be specific, I keep the original FLCE forward untouched and write a new one for `>=4.46.0`. If the HF version is `<4.46.0`, it shows a deprecation warning and falls back to the old FLCE forward.


```python
        if transformer_version >= version.parse("4.46.0"):
            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
        else:  # if version < 4.46.0
            logger.warning(
                "Support for transformers versions < 4.46.0 will soon be discontinued due to issues with incorrect gradient accumulation. "
                "Please consider upgrading to avoid potential issues. See details: huggingface/transformers#34191"
            )
            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
```


For more context on the grad acc fix, please see huggingface/transformers#34191.
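
For intuition, here is a toy sketch (not from this PR; the numbers are made up) of why per-micro-batch mean reduction goes wrong under gradient accumulation, and how normalizing a summed loss by `num_items_in_batch` fixes it. This mirrors what the new `lce_forward` does with `reduction="sum"`:

```python
# Toy illustration (numbers made up; not from this PR) of the grad acc issue
# fixed in huggingface/transformers#34191.
# Two micro-batches of one accumulation step with different numbers of label tokens:
micro_batch_token_losses = [
    [1.0, 1.0, 1.0, 1.0],  # 4 valid tokens
    [2.0, 2.0],            # 2 valid tokens
]

# Old behavior (< 4.46.0): mean per micro-batch, then average across micro-batches,
# so the short micro-batch gets the same weight as the long one.
buggy = sum(sum(l) / len(l) for l in micro_batch_token_losses) / len(micro_batch_token_losses)

# Fixed behavior: sum every token's loss, then divide once by the total token count,
# which is what `num_items_in_batch` carries into the patched `lce_forward`.
num_items_in_batch = sum(len(l) for l in micro_batch_token_losses)
fixed = sum(sum(l) for l in micro_batch_token_losses) / num_items_in_batch

print(buggy)  # 1.5
print(fixed)  # 1.333..., matches computing the loss over the full batch at once
```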

## TODO

- [ ] broadcast the changes to all models once the effect is verified.


## Testing Done

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
ByronHsu authored Oct 31, 2024
1 parent 337bf9a commit e28521b
Showing 3 changed files with 150 additions and 2 deletions.
File renamed without changes.
136 changes: 135 additions & 1 deletion src/liger_kernel/transformers/model/llama.py
@@ -22,7 +22,7 @@
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
def lce_forward_deprecated(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
@@ -145,3 +145,137 @@ def lce_forward(
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
    self,
    input_ids=None,
    attention_mask=None,
    position_ids=None,
    past_key_values=None,
    inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    cache_position=None,
    num_logits_to_keep=0,
    **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        num_logits_to_keep (`int`, *optional*):
            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
    Returns:
    Example:
    ```python
    >>> from transformers import AutoTokenizer, LlamaForCausalLM
    >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")
    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    """
    My TODO:
    1. Run e2e example with hf < 4.46.0, GA on/GA off
    2. Run e2e example with hf >= 4.46.0, GA on/GA off
    """

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]

    if self.config.pretraining_tp > 1:
        raise Exception("Liger Kernel does not support pretraining_tp!!")

    logits = None
    loss = None
    # if in training mode, don't materialize logits
    if self.training and (labels is not None):
        # We do the same thing as ForCausalLMLoss but using Liger FLCE

        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        # Transformers >= 4.46.0 passes `num_items_in_batch` (the total number of label
        # tokens across the accumulated micro-batches); in that case sum the loss here
        # and normalize by that global count so gradient accumulation matches
        # full-batch training.
        reduction = "sum" if "num_items_in_batch" in kwargs else "mean"
        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)

        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
        if reduction == "sum":
            loss /= kwargs["num_items_in_batch"]

    else:  # if in inference mode materialize logits
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
16 changes: 15 additions & 1 deletion src/liger_kernel/transformers/monkey_patch.py
@@ -3,13 +3,18 @@
from functools import partial
from typing import Callable

import transformers
from packaging import version
from transformers import PreTrainedModel

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
from liger_kernel.transformers.model.llama import (
    lce_forward_deprecated as llama_lce_forward_deprecated,
)
from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
@@ -22,6 +27,8 @@
    LigerSwiGLUMLP,
)

transformer_version = version.parse(transformers.__version__)

logger = logging.getLogger(__name__)


@@ -88,7 +95,14 @@ def apply_liger_kernel_to_llama(
    if cross_entropy:
        modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
        if transformer_version >= version.parse("4.46.0"):
            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
        else:  # if version < 4.46.0
            logger.warning(
                "Support for transformers versions < 4.46.0 will soon be discontinued due to issues with incorrect gradient accumulation. "
                "Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/34191"
            )
            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
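
As a usage note (not part of this commit): a minimal sketch of how the patched forward gets enabled, assuming the package's public `apply_liger_kernel_to_llama` entry point and its `fused_linear_cross_entropy` flag, as seen in the diff context above.

```python
# Usage sketch (not part of this commit; assumes liger_kernel's public
# `apply_liger_kernel_to_llama` entry point and its `fused_linear_cross_entropy` flag).
from transformers import AutoModelForCausalLM, AutoTokenizer

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch LlamaForCausalLM.forward; with transformers >= 4.46.0 this installs the
# new `lce_forward`, otherwise it warns and falls back to `lce_forward_deprecated`.
apply_liger_kernel_to_llama(fused_linear_cross_entropy=True)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# During training with labels, the patched forward skips logit materialization
# and computes the loss with LigerFusedLinearCrossEntropyLoss.
```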
