Add Llama Flax Implementation (#24587)
* Copies `modeling_flax_gpt_neo.py` to start

* MLP Block. WIP Attention and Block

* Adds Flax implementation of `LlamaMLP`
Validated with an in-file test.
Some slight numeric differences, but assuming they aren't an issue.

* Adds `FlaxLlamaRMSNorm` layer
`flax.linen` includes an `RMSNorm` layer, but not necessarily in all versions, so we add one in-file (sketched below).
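A minimal sketch of such an in-file layer, assuming the usual RMSNorm formulation; the epsilon default and parameter name are assumptions rather than the merged code:

```python
# A sketch of an in-file RMSNorm in flax.linen; epsilon and the weight
# initialiser are assumptions, not necessarily the merged implementation.
import flax.linen as nn
import jax.numpy as jnp


class FlaxRMSNorm(nn.Module):
    hidden_size: int
    eps: float = 1e-6
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, hidden_states):
        # Accumulate the variance in float32 for stability, whatever the input dtype.
        variance = jnp.asarray(hidden_states, dtype=jnp.float32)
        variance = (variance**2).mean(-1, keepdims=True)
        # 1 / sqrt rather than jax.lax.rsqrt -- see the numerics fix later in this log.
        hidden_states = hidden_states / jnp.sqrt(variance + self.eps)
        weight = self.param("weight", nn.initializers.ones, (self.hidden_size,))
        return (weight * hidden_states).astype(self.dtype)
```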

* Adds `FlaxLlamaAttention`
Copied from GPT-J as it has an efficient caching implementation as well as rotary embeddings.
Numerically different, but not by a huge amount. Needs investigating.

* Adds `FlaxLlamaDecoderLayer`
numerically inaccurate, debugging..

* debugging rotary mismatch
GPT-J uses interleaved rotary embeddings whilst Llama uses contiguous ones (see the sketch below).
I think they match now, but the final result is still wrong.
maybe drop back to just debugging the attention layer?
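For reference, a sketch of the two layouts, assuming the usual GPT-J and Llama conventions; function names are illustrative:

```python
# Illustrative comparison of the two rotary layouts.
import jax.numpy as jnp


def rotate_interleaved(tensor):
    # GPT-J style: rotate adjacent channel pairs (x0, x1), (x2, x3), ...
    rotated = jnp.stack((-tensor[..., 1::2], tensor[..., ::2]), axis=-1)
    return rotated.reshape(tensor.shape)


def rotate_contiguous(tensor):
    # Llama style: rotate the two contiguous halves of the last dimension.
    half = tensor.shape[-1] // 2
    return jnp.concatenate((-tensor[..., half:], tensor[..., :half]), axis=-1)
```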

* fixes bug with decoder layer
still somewhat numerically inaccurate, but close enough for now

* adds markers for what to implement next
the structure here diverges a lot from the PT version.
not a big fan of it, but just get something working for now

* implements `FlaxLlamaBlockCollection`
tolerance must be higher than expected, which is kinda disconcerting

* Adds `FlaxLlamaModule`
equivalent PyTorch model is `LlamaModel`
yay! a language model🤗

* adds `FlaxLlamaForCausalLMModule`
equivalent to `LlamaForCausalLM`
still missing returning dict or tuple, will add later

* start porting pretrained wrappers
realised it probably needs return dict as a prereq

* cleanup, quality, style

* re-adds `return_dict` and model output named tuples (sketched below)
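A minimal sketch of the usual Flax `return_dict` convention, assuming the standard `FlaxBaseModelOutput` container; the helper name is hypothetical:

```python
# Illustrative return_dict handling; the helper name is an assumption.
from transformers.modeling_flax_outputs import FlaxBaseModelOutput


def format_outputs(last_hidden_state, hidden_states, attentions, return_dict):
    if not return_dict:
        # plain tuple, dropping any None entries
        return tuple(v for v in (last_hidden_state, hidden_states, attentions) if v is not None)
    return FlaxBaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=hidden_states,
        attentions=attentions,
    )
```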

* (tentatively) pretrained wrappers work 🔥

* fixes numerical mismatch in `FlaxLlamaRMSNorm`
seems `jax.lax.rsqrt` does not match `torch.rsqrt`.
manually computing `1 / jax.numpy.sqrt` results in matching values (repro sketch below).
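A repro sketch of the mismatch, assuming it appears as small elementwise differences between the fused and manual forms:

```python
# jax.lax.rsqrt can use a fused approximation on some backends;
# 1 / jnp.sqrt tracks torch.rsqrt more closely.
import jax
import jax.numpy as jnp

x = jnp.abs(jax.random.normal(jax.random.PRNGKey(0), (4096,))) + 1e-6
fused = jax.lax.rsqrt(x)
manual = 1.0 / jnp.sqrt(x)
print(jnp.max(jnp.abs(fused - manual)))  # nonzero on some backends
```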

* [WIP] debugging numerics

* numerical match
I think the issue was an accidental change of backend; forcing CPU fixes the test (snippet below).
We expect some mismatch on GPU.
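One way to pin the backend, assuming the standard JAX config flag:

```python
# Pin JAX to CPU for numerical comparisons; must run before any
# JAX computation is dispatched.
import jax

jax.config.update("jax_platform_name", "cpu")
assert jax.default_backend() == "cpu"
```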

* adds in model and integration tests for Flax Llama
summary of failing tests:
- mul: invalid combination of dimensions
- one numerical mismatch
- bf16 conversion (maybe my local backend issue)
- params are not `FrozenDict` (see the sketch after this list)
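A sketch of the failing expectation from the last item, assuming params arrive as a plain dict:

```python
# Flax model params are normally a FrozenDict; a plain dict can be frozen explicitly.
from flax.core.frozen_dict import FrozenDict, freeze

params = {"layer": {"kernel": [1.0, 2.0]}}  # e.g. as returned by a weight converter
assert isinstance(freeze(params), FrozenDict)
```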

* adds missing TYPE_CHECKING import and `make fixup`

* adds back missing docstrings
needs review on quality of docstrings, not sure what is required.
Furthermore, need to check if `CHECKPOINT_FOR_DOC` is valid. See TODO

* commenting out equivalence test as we can just use the common one

* debugging

* Fixes bug where mask and pos_ids were swapped in pretrained models
This results in all tests passing now 🔥

* cleanup of modeling file

* cleanup of test file

* Resolving simpler review comments

* addresses more minor review comments

* fixing introduced pytest errors from review

* wip additional slow tests

* wip tests
need to grab a GPU machine to get real logits for comparison
otherwise, slow tests should be okay

* `make quality`, `make style`

* adds slow integration tests (see the sketch after this list)
- checking logits
- checking hidden states
- checking generation outputs
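A sketch of what such a test exercises; the checkpoint name is an assumption, and the real tests compare against logits captured from the PyTorch model on GPU:

```python
# Illustrative slow-test flow: load, check logits/hidden states, generate.
import jax.numpy as jnp
from transformers import FlaxLlamaForCausalLM, LlamaTokenizerFast

checkpoint = "openlm-research/open_llama_3b_v2"  # assumed checkpoint
model = FlaxLlamaForCausalLM.from_pretrained(checkpoint, from_pt=True)
tokenizer = LlamaTokenizerFast.from_pretrained(checkpoint)

inputs = tokenizer("Hello, my name is", return_tensors="np")
outputs = model(**inputs, output_hidden_states=True)

# logits: (batch, seq_len, vocab); hidden_states: one per layer plus embeddings
assert outputs.logits.shape == (1, inputs["input_ids"].shape[1], model.config.vocab_size)
assert len(outputs.hidden_states) == model.config.num_hidden_layers + 1

generated = model.generate(inputs["input_ids"], max_length=20)
print(tokenizer.decode(generated.sequences[0]))
```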

* `make fix-copies`

* fix mangled function following `make fix-copies`

* adds missing type checking imports

* fixes missing parameter checkpoint warning

* more fine-grained 'Copied from' tags
avoids issue of overwriting `LLAMA_INPUTS_DOCSTRING`

* swaps import guards
??? how did these get swapped initially?

* removing `inv_freq` again as the PyTorch version has now removed it

* attempting to get CI to pass

* adds doc entries for llama flax models

* fixes typo in __init__.py imports

* adds back special equivalence tests
these come from the GPT-Neo Flax tests; there is special behaviour for these models that needs to override the common version

* overrides tests with dummy to see if CI passes
need to fill in these tests later

* adds my contribution to docs

* `make style; make quality`

* replaces random masking with a fixed mask to work with the flax version

* `make quality; make style`

* Update src/transformers/models/llama/modeling_flax_llama.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* updates `x`->`tensor` in `rotate_half`

* addresses smaller review comments

* Update docs/source/en/model_doc/llama.md

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* adds integration test class

* adds `dtype` to rotary embedding to cast outputs
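A sketch of a sincos table builder with an output cast, mirroring the usual Llama layout; the function name is an assumption:

```python
# Illustrative rotary-position table with an output dtype cast.
import numpy as np
import jax.numpy as jnp


def create_sinusoidal_positions(num_pos, dim, dtype=jnp.float32):
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
    freqs = np.einsum("i,j->ij", np.arange(num_pos), inv_freq).astype("float32")
    emb = np.concatenate((freqs, freqs), axis=-1)
    # cast sin/cos tables to the requested dtype so downstream matmuls stay consistent
    return jnp.concatenate((jnp.sin(emb), jnp.cos(emb)), axis=-1).astype(dtype)
```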

* adds type to flax llama rotary layer

* `make style`

* `make fix-copies`

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* applies suggestions from review

* Update modeling_flax_llama.py

* `make fix-copies`

* Update tests/models/llama/test_modeling_llama.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* fixes shape mismatch in FlaxLlamaMLP

* applies some suggestions from reviews

* casts attn output logits to f32 regardless of dtype
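A minimal sketch of that cast; names are illustrative:

```python
# Compute attention weights in float32 regardless of module dtype, then cast back.
import jax
import jax.numpy as jnp


def attention_weights(scores, dtype):
    weights = jax.nn.softmax(scores.astype(jnp.float32), axis=-1)
    return weights.astype(dtype)
```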

* adds attn bias using `LlamaConfig.attention_bias`
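A sketch of wiring the flag into the projections; the module layout is illustrative, not the merged code:

```python
# Illustrative use of LlamaConfig.attention_bias for the q/k/v/o projections.
import flax.linen as nn
from transformers import LlamaConfig


class AttentionProjections(nn.Module):
    config: LlamaConfig

    def setup(self):
        dense = lambda: nn.Dense(self.config.hidden_size, use_bias=self.config.attention_bias)
        self.q_proj = dense()
        self.k_proj = dense()
        self.v_proj = dense()
        self.o_proj = dense()
```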

* adds Copied From comments to Flax Llama test

* mistral and persimmon test change: copied from llama

* updates docs index

* removes Copied from in tests

it was preventing `make fix-copies` from succeeding

* quality and style

* ignores FlaxLlama input docstring

* adds revision to `_CHECKPOINT_FOR_DOC`

* repo consistency and quality

* removes unused import

* removes copied from from Phi test

now diverges from llama tests following FlaxLlama changes

* adds `_REAL_CHECKPOINT_FOR_DOC`
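Illustratively, the pair of constants looks something like this; both values are assumptions rather than confirmed from the diff:

```python
# Assumed values: a tiny checkpoint for fast doc examples, plus the real one.
_CHECKPOINT_FOR_DOC = "afmck/testing-llama-tiny"
_REAL_CHECKPOINT_FOR_DOC = "openlm-research/open_llama_3b_v2"
```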

* removes refs from pr tests

* reformat to make ruff happy

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
vvvm23 and sanchit-gandhi authored Dec 7, 2023
1 parent 7fc8072 commit 75336c1
Showing 17 changed files with 1,071 additions and 16 deletions.
8 changes: 4 additions & 4 deletions docs/source/en/index.md
@@ -1,4 +1,4 @@
<!--Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
@@ -94,7 +94,7 @@ Flax), PyTorch, and/or TensorFlow.
| [CLIPSeg](model_doc/clipseg) | ✅ | ❌ | ❌ |
| [CLVP](model_doc/clvp) | ✅ | ❌ | ❌ |
| [CodeGen](model_doc/codegen) | ✅ | ❌ | ❌ |
- | [CodeLlama](model_doc/code_llama) | ✅ | ❌ | ❌ |
+ | [CodeLlama](model_doc/code_llama) | ✅ | ❌ | ✅ |
| [Conditional DETR](model_doc/conditional_detr) | ✅ | ❌ | ❌ |
| [ConvBERT](model_doc/convbert) | ✅ | ✅ | ❌ |
| [ConvNeXT](model_doc/convnext) | ✅ | ✅ | ❌ |
@@ -167,8 +167,8 @@ Flax), PyTorch, and/or TensorFlow.
| [LED](model_doc/led) | ✅ | ✅ | ❌ |
| [LeViT](model_doc/levit) | ✅ | ❌ | ❌ |
| [LiLT](model_doc/lilt) | ✅ | ❌ | ❌ |
- | [LLaMA](model_doc/llama) | ✅ | ❌ | ❌ |
- | [Llama2](model_doc/llama2) | ✅ | ❌ | ❌ |
+ | [LLaMA](model_doc/llama) | ✅ | ❌ | ✅ |
+ | [Llama2](model_doc/llama2) | ✅ | ❌ | ✅ |
| [Longformer](model_doc/longformer) | ✅ | ✅ | ❌ |
| [LongT5](model_doc/longt5) | ✅ | ❌ | ✅ |
| [LUKE](model_doc/luke) | ✅ | ❌ | ❌ |
13 changes: 13 additions & 0 deletions docs/source/en/model_doc/llama.md
@@ -50,6 +50,9 @@ come in several checkpoints they each contain a part of each weight of the model

- The LLaMA tokenizer is a BPE model based on [sentencepiece](https://github.com/google/sentencepiece). One quirk of sentencepiece is that when decoding a sequence, if the first token is the start of the word (e.g. "Banana"), the tokenizer does not prepend the prefix space to the string.

This model was contributed by [zphang](https://huggingface.co/zphang) with contributions from [BlackSamorez](https://huggingface.co/BlackSamorez). The code of the implementation in Hugging Face is based on GPT-NeoX [here](https://github.com/EleutherAI/gpt-neox). The original code of the authors can be found [here](https://github.com/facebookresearch/llama). The Flax version of the implementation was contributed by [afmck](https://huggingface.co/afmck) with the code in the implementation based on Hugging Face's Flax GPT-Neo.


Based on the original LLaMA model, Meta AI has released some follow-up works:

- **Llama2**: Llama2 is an improved version of Llama with some architectural tweaks (Grouped Query Attention), and is pre-trained on 2 trillion tokens. Refer to the documentation of Llama2 which can be found [here](llama2).
@@ -112,3 +115,13 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h

[[autodoc]] LlamaForSequenceClassification
- forward

## FlaxLlamaModel

[[autodoc]] FlaxLlamaModel
- __call__

## FlaxLlamaForCausalLM

[[autodoc]] FlaxLlamaForCausalLM
- __call__
6 changes: 6 additions & 0 deletions src/transformers/__init__.py
@@ -4554,6 +4554,7 @@
["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"]
)
_import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"])
_import_structure["models.llama"].extend(["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"])
_import_structure["models.longt5"].extend(
[
"FlaxLongT5ForConditionalGeneration",
@@ -8631,6 +8632,11 @@
        FlaxGPTJModel,
        FlaxGPTJPreTrainedModel,
    )
    from .models.llama import (
        FlaxLlamaForCausalLM,
        FlaxLlamaModel,
        FlaxLlamaPreTrainedModel,
    )
    from .models.longt5 import (
        FlaxLongT5ForConditionalGeneration,
        FlaxLongT5Model,
5 changes: 4 additions & 1 deletion src/transformers/modeling_flax_utils.py
@@ -1267,14 +1267,17 @@ def overwrite_call_docstring(model_class, docstring):
    model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)


-def append_call_sample_docstring(model_class, checkpoint, output_type, config_class, mask=None, revision=None):
+def append_call_sample_docstring(
+    model_class, checkpoint, output_type, config_class, mask=None, revision=None, real_checkpoint=None
+):
    model_class.__call__ = copy_func(model_class.__call__)
    model_class.__call__ = add_code_sample_docstrings(
        checkpoint=checkpoint,
        output_type=output_type,
        config_class=config_class,
        model_cls=model_class.__name__,
        revision=revision,
        real_checkpoint=real_checkpoint,
    )(model_class.__call__)


2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
@@ -43,6 +43,7 @@
("gpt2", "FlaxGPT2Model"),
("gpt_neo", "FlaxGPTNeoModel"),
("gptj", "FlaxGPTJModel"),
("llama", "FlaxLlamaModel"),
("longt5", "FlaxLongT5Model"),
("marian", "FlaxMarianModel"),
("mbart", "FlaxMBartModel"),
@@ -146,6 +147,7 @@
("gpt2", "FlaxGPT2LMHeadModel"),
("gpt_neo", "FlaxGPTNeoForCausalLM"),
("gptj", "FlaxGPTJForCausalLM"),
("llama", "FlaxLlamaForCausalLM"),
("opt", "FlaxOPTForCausalLM"),
("roberta", "FlaxRobertaForCausalLM"),
("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"),
2 changes: 1 addition & 1 deletion src/transformers/models/bloom/modeling_bloom.py
@@ -489,7 +489,7 @@ def _convert_to_standard_cache(

    @staticmethod
    def _convert_to_bloom_cache(
-        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
2 changes: 1 addition & 1 deletion src/transformers/models/fuyu/image_processing_fuyu.py
@@ -54,7 +54,7 @@


def make_list_of_list_of_images(
-    images: Union[List[List[ImageInput]], List[ImageInput], ImageInput]
+    images: Union[List[List[ImageInput]], List[ImageInput], ImageInput],
) -> List[List[ImageInput]]:
    if is_valid_image(images):
        return [[images]]
17 changes: 17 additions & 0 deletions src/transformers/models/llama/__init__.py
@@ -16,6 +16,7 @@
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_flax_available,
    is_sentencepiece_available,
    is_tokenizers_available,
    is_torch_available,
@@ -55,6 +56,14 @@
"LlamaForSequenceClassification",
]

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_llama"] = ["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"]


if TYPE_CHECKING:
    from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlamaConfig
@@ -83,6 +92,14 @@
    else:
        from .modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_flax_llama import FlaxLlamaForCausalLM, FlaxLlamaModel, FlaxLlamaPreTrainedModel


else:
    import sys
Expand Down