BLOOM Flax #18022

Closed · wants to merge 77 commits

Commits
cd19981  First commit (younesbelkada, Jul 5, 2022)
23c0f74  step 1 working (younesbelkada, Jul 5, 2022)
9d33197  add alibi (younesbelkada, Jul 5, 2022)
1d3cf96  placeholder for `scan` (sanchit-gandhi, Jul 5, 2022)
bcb1eeb  add matrix mult alibi (younesbelkada, Jul 5, 2022)
e969410  beta scaling factor for bmm (sanchit-gandhi, Jul 5, 2022)
ad537e7  working v1 - simple forward pass (younesbelkada, Jul 5, 2022)
b89c268  move layer_number from attribute to arg in call (sanchit-gandhi, Jul 5, 2022)
a14566f  partial functioning scan (sanchit-gandhi, Jul 5, 2022)
086e013  add more modifs (younesbelkada, Jul 5, 2022)
846c658  add test (younesbelkada, Jul 5, 2022)
ed8d023  hacky working scan (sanchit-gandhi, Jul 5, 2022)
2aeac52  Merge remote-tracking branch 'younesbelkada/add_bloom_flax' into add_… (sanchit-gandhi, Jul 5, 2022)
768784c  update scan for new kwarg order (sanchit-gandhi, Jul 5, 2022)
0998f31  Merge branch 'main' of https://github.com/huggingface/transformers in… (patrickvonplaten, Jul 5, 2022)
8f0a509  fix position_ids problem (patrickvonplaten, Jul 5, 2022)
f96979f  fix bug in attention layer (patrickvonplaten, Jul 5, 2022)
af86c28  small fix (younesbelkada, Jul 6, 2022)
2b4f22e  alibi shifting (younesbelkada, Jul 6, 2022)
b682321  prelim refactor (sanchit-gandhi, Jul 6, 2022)
3444985  finish refactor (sanchit-gandhi, Jul 6, 2022)
807aae5  Merge remote-tracking branch 'younesbelkada/add_bloom_flax' into add_… (sanchit-gandhi, Jul 6, 2022)
afd4776  incorporate dropout_add to attention module (sanchit-gandhi, Jul 6, 2022)
b6bb401  make style (patrickvonplaten, Jul 6, 2022)
d4f39d3  make padding work again (patrickvonplaten, Jul 6, 2022)
4bcfc1b  update (patrickvonplaten, Jul 6, 2022)
e9ad307  remove bogus file (patrickvonplaten, Jul 6, 2022)
0067b06  up (patrickvonplaten, Jul 6, 2022)
7de2e96  get generation to work (patrickvonplaten, Jul 6, 2022)
fc9c4d5  clean code a bit (patrickvonplaten, Jul 6, 2022)
fbd229f  added small tests (younesbelkada, Jul 6, 2022)
67c8325  adding albii test (younesbelkada, Jul 6, 2022)
9e282c2  make CI tests pass: (younesbelkada, Jul 7, 2022)
9043bdd  fix few nits (younesbelkada, Jul 7, 2022)
f77a6fe  fix nit onnx (younesbelkada, Jul 7, 2022)
564b400  fix onnx nit (younesbelkada, Jul 7, 2022)
d16ec40  add missing dtype args to nn.Modules (sanchit-gandhi, Jul 7, 2022)
0571fa1  remove debugging statements (sanchit-gandhi, Jul 7, 2022)
ccd35c3  fix scan generate (sanchit-gandhi, Jul 8, 2022)
400987b  Merge branch 'main' into add_bloom_flax (younesbelkada, Jul 18, 2022)
f78e659  Update modeling_flax_bloom.py (younesbelkada, Jul 18, 2022)
d0b6e2b  Update test_modeling_flax_bloom.py (younesbelkada, Jul 18, 2022)
266abe8  Update test_modeling_flax_bloom.py (younesbelkada, Jul 18, 2022)
6f01ad9  Update test_modeling_flax_bloom.py (younesbelkada, Jul 18, 2022)
0ab14b4  fix small test issue + make style (younesbelkada, Jul 18, 2022)
91a7511  clean up (sanchit-gandhi, Jul 27, 2022)
d2b1cd0  Update tests/models/bloom/test_modeling_flax_bloom.py (younesbelkada, Jul 27, 2022)
d59baf7  fix function name (younesbelkada, Jul 27, 2022)
07a1b4f  small fix test (younesbelkada, Jul 27, 2022)
0b7168d  forward contrib credits from PR17761 (haileyschoelkopf, Jul 28, 2022)
cecfcb1  Fix failing test (younesbelkada, Jul 29, 2022)
74f5b07  fix small typo documentation (younesbelkada, Jul 29, 2022)
513cbc7  Merge branch 'main' into add_bloom_flax (younesbelkada, Aug 10, 2022)
c7e9574  fix non passing test (younesbelkada, Aug 10, 2022)
dda4a08  refactor call (younesbelkada, Sep 12, 2022)
822486d  Merge remote-tracking branch 'upstream2/main' into add_bloom_flax (younesbelkada, Sep 12, 2022)
ecfaf81  make style (younesbelkada, Sep 12, 2022)
c685442  upcast to fp32 (younesbelkada, Sep 12, 2022)
48e7dac  cleaner way to upcast (younesbelkada, Sep 12, 2022)
1999df7  remove unused args (younesbelkada, Sep 12, 2022)
fc37126  remove layer number (younesbelkada, Sep 12, 2022)
21b68e1  fix scan test (younesbelkada, Sep 12, 2022)
3ecbae1  make style (younesbelkada, Sep 12, 2022)
05c0521  fix i4 casting (younesbelkada, Sep 13, 2022)
e18a2c0  Merge remote-tracking branch 'upstream2/main' into add_bloom_flax (younesbelkada, Sep 13, 2022)
ea7dc48  Update src/transformers/models/bloom/modeling_flax_bloom.py (younesbelkada, Oct 11, 2022)
0ba8cdd  fix slow test (younesbelkada, Oct 11, 2022)
2c51c27  Merge branch 'add_bloom_flax' of https://github.com/younesbelkada/tra… (younesbelkada, Oct 11, 2022)
0349f8c  Merge remote-tracking branch 'upstream/main' into add_bloom_flax (younesbelkada, Oct 11, 2022)
0b70bf2  remove `layer_past` (younesbelkada, Oct 11, 2022)
b002f64  refactor a bit (younesbelkada, Oct 11, 2022)
42faf0e  fix `scan` slow test (younesbelkada, Oct 11, 2022)
453bdc1  remove useless import (younesbelkada, Oct 11, 2022)
7f9d74e  major changes (younesbelkada, Oct 11, 2022)
38c5946  major refactoring (younesbelkada, Oct 11, 2022)
dcdd563  Merge remote-tracking branch 'upstream/main' into add_bloom_flax (younesbelkada, Oct 12, 2022)
41ab72e  Merge remote-tracking branch 'upstream/main' into add_bloom_flax (younesbelkada, Oct 13, 2022)
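Several of the commits above (`add alibi`, `add matrix mult alibi`, `alibi shifting`) concern BLOOM's ALiBi attention bias. As an illustration only — not this PR's actual implementation — here is a minimal NumPy sketch of the standard ALiBi construction for a power-of-two number of heads: each head gets a slope from a geometric sequence, and the bias added to a head's attention scores grows linearly with key position.

```python
import numpy as np

def alibi_slopes(n_heads: int) -> np.ndarray:
    # geometric sequence of per-head slopes, 2^(-8i/n) for i = 1..n,
    # valid as written only for power-of-two head counts
    start = 2 ** (-8.0 / n_heads)
    return np.array([start ** (i + 1) for i in range(n_heads)])

def build_alibi_bias(n_heads: int, seq_len: int) -> np.ndarray:
    slopes = alibi_slopes(n_heads)        # shape (heads,)
    positions = np.arange(seq_len)        # shape (seq,)
    # bias added to attention scores: slope * key position
    # (per-row constant shifts cancel in the softmax)
    return slopes[:, None] * positions[None, :]   # shape (heads, seq)

bias = build_alibi_bias(8, 4)
```

With 8 heads the first slope is 2^(-8/8) = 0.5, so the first head's bias over 4 positions is [0, 0.5, 1.0, 1.5].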
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
@@ -216,7 +216,7 @@ Flax), PyTorch, and/or TensorFlow.
| BigBird-Pegasus | ❌ | ❌ | ✅ | ❌ | ❌ |
| Blenderbot | ✅ | ✅ | ✅ | ✅ | ✅ |
| BlenderbotSmall | ✅ | ✅ | ✅ | ✅ | ✅ |
-| BLOOM | ❌ | ✅ | ✅ | ❌ | ❌ |
+| BLOOM | ❌ | ✅ | ✅ | ❌ | ✅ |
| CamemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| CANINE | ✅ | ❌ | ✅ | ❌ | ❌ |
| CLIP | ✅ | ✅ | ✅ | ✅ | ✅ |
10 changes: 10 additions & 0 deletions docs/source/en/model_doc/bloom.mdx
@@ -60,3 +60,13 @@ Several smaller versions of the models have been trained on the same dataset. BL

[[autodoc]] BloomForQuestionAnswering
- forward

## FlaxBloomModel

[[autodoc]] FlaxBloomModel
- __call__

## FlaxBloomForCausalLM

[[autodoc]] FlaxBloomForCausalLM
- __call__
8 changes: 8 additions & 0 deletions src/transformers/__init__.py
@@ -2939,6 +2939,13 @@
"FlaxBlenderbotSmallPreTrainedModel",
]
)
_import_structure["models.bloom"].extend(
[
"FlaxBloomForCausalLM",
"FlaxBloomModel",
"FlaxBloomPreTrainedModel",
]
)
_import_structure["models.clip"].extend(
[
"FlaxCLIPModel",
@@ -5471,6 +5478,7 @@
        FlaxBlenderbotSmallModel,
        FlaxBlenderbotSmallPreTrainedModel,
    )
    from .models.bloom import FlaxBloomForCausalLM, FlaxBloomModel, FlaxBloomPreTrainedModel
    from .models.clip import (
        FlaxCLIPModel,
        FlaxCLIPPreTrainedModel,
17 changes: 15 additions & 2 deletions src/transformers/modeling_flax_pytorch_utils.py
@@ -115,7 +115,18 @@ def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool:

def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    # convert pytorch tensor to numpy
-   pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
+   # numpy currently does not support bfloat16; go via float32 in this case to not lose precision
+   try:
+       import torch  # noqa: F401
+   except ImportError:
+       logger.error(
+           "Loading a PyTorch model in Flax requires both PyTorch and Flax to be installed. Please see"
+           " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
+           " instructions."
+       )
+       raise
+   is_bfloat_16 = all(v.dtype == torch.bfloat16 for v in pt_state_dict.values())  # noqa: F821
+   pt_state_dict = {k: v.numpy() if not is_bfloat_16 else v.float().numpy() for k, v in pt_state_dict.items()}

    model_prefix = flax_model.base_model_prefix
    random_flax_state_dict = flatten_dict(flax_model.params)
@@ -156,7 +167,9 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
            )

        # also add unexpected weight so that warning is thrown
-       flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
+       flax_state_dict[flax_key] = (
+           jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
+       )

    return unflatten_dict(flax_state_dict)

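The hunk above works around NumPy's lack of a native bfloat16 dtype by upcasting every tensor through float32 before calling `.numpy()`. A self-contained sketch of the same pattern, using a hypothetical stub class in place of `torch.Tensor` so it runs without PyTorch (the stub and its string dtypes are for illustration only):

```python
class FakeTensor:
    """Stands in for a torch tensor; dtype is a plain string here."""
    def __init__(self, values, dtype="float32"):
        self.values = list(values)
        self.dtype = dtype

    def float(self):
        # upcast bfloat16 -> float32 (a no-op on the stub's values)
        return FakeTensor(self.values, dtype="float32")

    def numpy(self):
        # mimic numpy's behavior: bfloat16 has no numpy equivalent
        if self.dtype == "bfloat16":
            raise TypeError("numpy does not support bfloat16")
        return self.values

state_dict = {
    "w": FakeTensor([1.0, 2.0], dtype="bfloat16"),
    "b": FakeTensor([0.5], dtype="bfloat16"),
}

# mirror the diff: if the checkpoint is bfloat16, upcast via .float()
is_bfloat_16 = all(v.dtype == "bfloat16" for v in state_dict.values())
np_state_dict = {
    k: (v.float().numpy() if is_bfloat_16 else v.numpy())
    for k, v in state_dict.items()
}
```

Calling `.numpy()` directly on the bfloat16 stub would raise, which is exactly why the diff routes through `.float()` first.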
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
@@ -35,6 +35,7 @@
("big_bird", "FlaxBigBirdModel"),
("blenderbot", "FlaxBlenderbotModel"),
("blenderbot-small", "FlaxBlenderbotSmallModel"),
("bloom", "FlaxBloomModel"),
("clip", "FlaxCLIPModel"),
("distilbert", "FlaxDistilBertModel"),
("electra", "FlaxElectraModel"),
@@ -129,6 +130,7 @@
("bart", "FlaxBartForCausalLM"),
("bert", "FlaxBertForCausalLM"),
("big_bird", "FlaxBigBirdForCausalLM"),
("bloom", "FlaxBloomForCausalLM"),
("electra", "FlaxElectraForCausalLM"),
("gpt2", "FlaxGPT2LMHeadModel"),
("gpt_neo", "FlaxGPTNeoForCausalLM"),
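The auto classes resolve a config's `model_type` string to a concrete Flax class through name mappings like the two edited above. A simplified sketch of that lookup (the helper function and the abridged mapping are illustrative, not the library's actual API):

```python
# model_type -> Flax class name, abridged from the mapping in the diff
FLAX_MODEL_MAPPING_NAMES = {
    "blenderbot": "FlaxBlenderbotModel",
    "bloom": "FlaxBloomModel",   # entry added by this PR
    "clip": "FlaxCLIPModel",
}

def flax_model_class_name(model_type: str) -> str:
    # raise a descriptive error for unsupported architectures
    try:
        return FLAX_MODEL_MAPPING_NAMES[model_type]
    except KeyError:
        raise ValueError(f"No Flax model mapped for model_type={model_type!r}")
```

With this one-line mapping entry in place, `AutoModel`-style dispatch can instantiate the BLOOM Flax classes from a config alone.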
34 changes: 32 additions & 2 deletions src/transformers/models/bloom/__init__.py
@@ -18,11 +18,21 @@

from typing import TYPE_CHECKING

-from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import (
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    is_flax_available,
+    is_tokenizers_available,
+    is_torch_available,
+)


_import_structure = {
-    "configuration_bloom": ["BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP", "BloomConfig", "BloomOnnxConfig"],
+    "configuration_bloom": [
+        "BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "BloomConfig",
+        "BloomOnnxConfig",
+    ],
}
try:
    if not is_tokenizers_available():
@@ -48,6 +58,19 @@
        "BloomForQuestionAnswering",
    ]

try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_flax_bloom"] = [
        "FlaxBloomForCausalLM",
        "FlaxBloomModel",
        "FlaxBloomPreTrainedModel",
    ]


if TYPE_CHECKING:
    from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig, BloomOnnxConfig

@@ -75,6 +98,13 @@
            BloomPreTrainedModel,
        )

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_flax_bloom import FlaxBloomForCausalLM, FlaxBloomModel, FlaxBloomPreTrainedModel
else:
    import sys
