Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BigBird #10183

Merged
merged 88 commits into from
Mar 30, 2021
Merged

BigBird #10183

Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
88 commits
Select commit Hold shift + click to select a range
d0aa9ea
init bigbird
thevasudevgupta Feb 14, 2021
0e37183
model.__init__ working, conversion script ready, config updated
thevasudevgupta Feb 15, 2021
faa4ab3
add conversion script
thevasudevgupta Feb 16, 2021
facd93e
BigBirdEmbeddings working :)
thevasudevgupta Feb 16, 2021
28c8d13
slightly update conversion script
patrickvonplaten Feb 16, 2021
124b99f
BigBirdAttention working :) ; some bug in layer.output.dense
thevasudevgupta Feb 17, 2021
2c14788
add debugger-notebook
thevasudevgupta Feb 17, 2021
12a523b
forward() working for BigBirdModel :) ; replaced gelu with gelu_fast
thevasudevgupta Feb 17, 2021
aebd36b
tf code adapted to torch till rand_attn in bigbird_block_sparse_atten…
thevasudevgupta Feb 19, 2021
9df3127
BigBirdModel working in block-sparse attention mode :)
thevasudevgupta Feb 23, 2021
ad84acf
add BigBirdForPreTraining
thevasudevgupta Feb 24, 2021
4076c9b
small fix
thevasudevgupta Feb 24, 2021
78a205a
add tokenizer for BigBirdModel
thevasudevgupta Feb 25, 2021
644f65d
fix config & hence modeling
thevasudevgupta Feb 25, 2021
ce66bac
fix base prefix
thevasudevgupta Feb 25, 2021
f672205
init testing
thevasudevgupta Feb 25, 2021
372ff99
init tokenizer test
thevasudevgupta Feb 26, 2021
ed6dc49
pos_embed must be absolute, attn_type=original_full when add_cross_at…
thevasudevgupta Feb 26, 2021
7e05539
remove position_embedding_type arg
thevasudevgupta Feb 26, 2021
d257079
complete normal tests
thevasudevgupta Feb 26, 2021
07ec9a1
add comments to block sparse attention
thevasudevgupta Feb 27, 2021
01dd2e8
add attn_probs for sliding & global tokens
thevasudevgupta Feb 27, 2021
49d62e5
create fn for block sparse attn mask creation
thevasudevgupta Feb 28, 2021
5912716
add special tests
thevasudevgupta Feb 28, 2021
89de3c5
restore pos embed arg
thevasudevgupta Feb 28, 2021
b132905
minor fix
thevasudevgupta Feb 28, 2021
6ab2921
attn probs update
thevasudevgupta Mar 1, 2021
72e2532
make big bird fully gpu friendly
patrickvonplaten Mar 2, 2021
7401768
fix tests
patrickvonplaten Mar 2, 2021
da2824f
remove pruning
patrickvonplaten Mar 2, 2021
3a866e2
correct tokenzier & minor fixes
patrickvonplaten Mar 2, 2021
753ba75
update conversion script , remove norm_type
thevasudevgupta Mar 2, 2021
1e186d0
tokenizer-inference test add
thevasudevgupta Mar 2, 2021
72a150e
remove extra comments
thevasudevgupta Mar 2, 2021
24c74a9
add docs
thevasudevgupta Mar 3, 2021
79955e4
save intermediate
patrickvonplaten Mar 3, 2021
018b8fd
finish trivia_qa conversion
patrickvonplaten Mar 4, 2021
1716dea
small update to forward
thevasudevgupta Mar 4, 2021
15b7cfa
correct qa and layer
patrickvonplaten Mar 4, 2021
c300f3f
merge into master
patrickvonplaten Mar 4, 2021
2782295
better error message
patrickvonplaten Mar 4, 2021
56bd1d8
BigBird QA ready
thevasudevgupta Mar 5, 2021
ecfe137
fix rebased
thevasudevgupta Mar 5, 2021
eebd92a
add triva-qa debugger notebook
thevasudevgupta Mar 5, 2021
f6b6f43
qa setup
thevasudevgupta Mar 6, 2021
a50a10c
fixed till embeddings
thevasudevgupta Mar 7, 2021
edf5f2a
some issue in q/k/v_layer
thevasudevgupta Mar 8, 2021
a94d006
fix bug in conversion-script
thevasudevgupta Mar 9, 2021
3b489a3
fixed till self-attn
thevasudevgupta Mar 9, 2021
1e3aa50
qa fixed except layer norm
thevasudevgupta Mar 11, 2021
2f59e51
add qa end2end test
thevasudevgupta Mar 12, 2021
ef72bcd
fix gradient ckpting ; other qa test
thevasudevgupta Mar 12, 2021
8b94584
speed-up big bird a bit
patrickvonplaten Mar 15, 2021
468de78
hub_id=google
thevasudevgupta Mar 12, 2021
58ee280
clean up
thevasudevgupta Mar 15, 2021
e873658
make quality
thevasudevgupta Mar 15, 2021
4e13753
speed up einsum with bmm
patrickvonplaten Mar 16, 2021
e88110a
finish perf improvements for big bird
patrickvonplaten Mar 16, 2021
5f2d6a0
Merge branch 'master' into add_big_bird
patrickvonplaten Mar 16, 2021
cada132
Merge branch 'master' into add_big_bird
patrickvonplaten Mar 22, 2021
b8f41c0
remove wav2vec2 tok
patrickvonplaten Mar 22, 2021
22a71cc
fix tokenizer
patrickvonplaten Mar 22, 2021
5730a98
include docs
patrickvonplaten Mar 22, 2021
ab65872
correct docs
patrickvonplaten Mar 22, 2021
ff32248
add helper to auto pad block size
thevasudevgupta Mar 25, 2021
de2f812
make style
thevasudevgupta Mar 25, 2021
1b0e5f1
remove fast tokenizer for now
patrickvonplaten Mar 25, 2021
1ff2ff0
fix some
thevasudevgupta Mar 25, 2021
87a4e8c
add pad test
thevasudevgupta Mar 25, 2021
b20906c
finish
patrickvonplaten Mar 28, 2021
00cd6fb
fix some bugs
patrickvonplaten Mar 28, 2021
a719f1f
fix another bug
patrickvonplaten Mar 28, 2021
66fbec6
fix buffer tokens
thevasudevgupta Mar 29, 2021
184d361
:wqalMerge branch 'master' of https://github.com/huggingface/transfor…
patrickvonplaten Mar 29, 2021
1af7c98
Merge branch 'add_big_bird' of https://github.com/vasudevgupta7/trans…
patrickvonplaten Mar 29, 2021
aca2b4b
fix comment and merge from master
patrickvonplaten Mar 29, 2021
ef673bb
add comments
thevasudevgupta Mar 29, 2021
a6018bf
make style
patrickvonplaten Mar 29, 2021
58ef450
Merge branch 'master' of https://github.com/huggingface/transformers …
patrickvonplaten Mar 29, 2021
8a47841
commit some suggestions
thevasudevgupta Mar 29, 2021
dbc6e39
Fix typos
sgugger Mar 29, 2021
25164b9
fix some more suggestions
thevasudevgupta Mar 29, 2021
7bbbd6b
add another patch
thevasudevgupta Mar 29, 2021
ab6755e
fix copies
thevasudevgupta Mar 29, 2021
a9779b2
another path
thevasudevgupta Mar 29, 2021
df70258
update
thevasudevgupta Mar 29, 2021
0f110c5
update nit suggestions
thevasudevgupta Mar 29, 2021
8332604
make style
patrickvonplaten Mar 30, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions docs/source/model_doc/big_bird.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
..
Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

BigBird
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The BigBird model was proposed in `<INSERT PAPER NAME HERE>
<<INSERT PAPER LINK HERE>>`__ by <INSERT AUTHORS HERE>. <INSERT SHORT SUMMARY HERE>

The abstract from the paper is the following:

*<INSERT PAPER ABSTRACT HERE>*

Tips:

<INSERT TIPS ABOUT MODEL HERE>

BigBirdConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BigBirdConfig
:members:


BigBirdTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BigBirdTokenizer
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
create_token_type_ids_from_sequences, save_vocabulary


BigBirdTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BigBirdTokenizerFast
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
create_token_type_ids_from_sequences, save_vocabulary


BigBirdModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BigBirdModel
:members: forward


BigBirdForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BigBirdForCausalLM
:members: forward


BigBirdForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BigBirdForMaskedLM
:members: forward


BigBirdForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BigBirdForSequenceClassification
:members: forward


BigBirdForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BigBirdForMultipleChoice
:members: forward


BigBirdForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BigBirdForTokenClassification
:members: forward


BigBirdForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BigBirdForQuestionAnswering
:members: forward
34 changes: 34 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
],
"models": [],
# Models
"models.big_bird": ["BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdConfig", "BigBirdTokenizer"],
"models.wav2vec2": ["WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Wav2Vec2Config", "Wav2Vec2Tokenizer"],
"models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"],
"models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"],
Expand Down Expand Up @@ -277,6 +278,7 @@
# tokenziers-backed objects
if is_tokenizers_available():
# Fast tokenizers
_import_structure["models.big_bird"].append("BigBirdTokenizerFast")
_import_structure["models.convbert"].append("ConvBertTokenizerFast")
_import_structure["models.albert"].append("AlbertTokenizerFast")
_import_structure["models.bart"].append("BartTokenizerFast")
Expand Down Expand Up @@ -364,6 +366,22 @@
_import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"]
# PyTorch models structure

_import_structure["models.big_bird"].extend(
[
"BIG_BIRD_PRETRAINED_MODEL_ARCHIVE_LIST",
"BigBirdForMaskedLM",
"BigBirdForCausalLM",
"BigBirdForMultipleChoice",
"BigBirdForQuestionAnswering",
"BigBirdForSequenceClassification",
"BigBirdForTokenClassification",
"BigBirdLayer",
"BigBirdModel",
"BigBirdPreTrainedModel",
"load_tf_weights_in_big_bird",
]
)

_import_structure["models.wav2vec2"].extend(
[
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -1256,6 +1274,7 @@
load_tf2_weights_in_pytorch_model,
)
from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .models.big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig, BigBirdTokenizer
from .models.auto import (
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
CONFIG_MAPPING,
Expand Down Expand Up @@ -1403,6 +1422,7 @@
from .utils.dummy_sentencepiece_objects import *

if is_tokenizers_available():
from .models.big_bird import BigBirdTokenizerFast
from .models.albert import AlbertTokenizerFast
from .models.bart import BartTokenizerFast
from .models.barthez import BarthezTokenizerFast
Expand Down Expand Up @@ -1442,6 +1462,20 @@
# Modeling
if is_torch_available():

from .models.big_bird import (
BIG_BIRD_PRETRAINED_MODEL_ARCHIVE_LIST,
BigBirdForMaskedLM,
BigBirdForCausalLM,
BigBirdForMultipleChoice,
BigBirdForQuestionAnswering,
BigBirdForSequenceClassification,
BigBirdForTokenClassification,
BigBirdLayer,
BigBirdModel,
BigBirdPreTrainedModel,
load_tf_weights_in_big_bird,
)

# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# limitations under the License.

from . import (
big_bird,
albert,
auto,
bart,
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from ...configuration_utils import PretrainedConfig
from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from ..big_bird.configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig
from ..bart.configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
from ..bert.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from ..bert_generation.configuration_bert_generation import BertGenerationConfig
Expand Down Expand Up @@ -73,6 +74,7 @@
(key, value)
for pretrained_map in [
# Add archive maps here
BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP,
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
LED_PRETRAINED_CONFIG_ARCHIVE_MAP,
Expand Down Expand Up @@ -116,6 +118,7 @@
CONFIG_MAPPING = OrderedDict(
[
# Add configs here
("big_bird", BigBirdConfig),
("wav2vec2", Wav2Vec2Config),
("convbert", ConvBertConfig),
("led", LEDConfig),
Expand Down Expand Up @@ -165,6 +168,7 @@
MODEL_NAMES_MAPPING = OrderedDict(
[
# Add full (and cased) model names here
("big_bird", "BigBird"),
("wav2vec2", "Wav2Vec2"),
("convbert", "ConvBERT"),
("led", "LED"),
Expand Down
27 changes: 27 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@
from ...utils import logging

# Add modeling imports here
thevasudevgupta marked this conversation as resolved.
Show resolved Hide resolved
from ..big_bird.modeling_big_bird import (
BigBirdForMaskedLM,
BigBirdForCausalLM,
BigBirdForMultipleChoice,
BigBirdForQuestionAnswering,
BigBirdForSequenceClassification,
BigBirdForTokenClassification,
BigBirdModel,
)
from ..albert.modeling_albert import (
AlbertForMaskedLM,
AlbertForMultipleChoice,
Expand Down Expand Up @@ -68,6 +77,15 @@
)

# Add modeling imports here
from ..big_bird.modeling_big_bird import (
BigBirdForMaskedLM,
BigBirdForCausalLM,
BigBirdForMultipleChoice,
BigBirdForQuestionAnswering,
BigBirdForSequenceClassification,
BigBirdForTokenClassification,
BigBirdModel,
)
from ..convbert.modeling_convbert import (
ConvBertForMaskedLM,
ConvBertForMultipleChoice,
Expand Down Expand Up @@ -243,6 +261,7 @@
XLNetModel,
)
from .configuration_auto import (
BigBirdConfig,
AlbertConfig,
AutoConfig,
BartConfig,
Expand Down Expand Up @@ -296,6 +315,7 @@
MODEL_MAPPING = OrderedDict(
[
# Base model mapping
(BigBirdConfig, BigBirdModel),
(Wav2Vec2Config, Wav2Vec2Model),
(ConvBertConfig, ConvBertModel),
(LEDConfig, LEDModel),
Expand Down Expand Up @@ -376,6 +396,7 @@
MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
[
# Model with LM heads mapping
(BigBirdConfig, BigBirdForMaskedLM),
(Wav2Vec2Config, Wav2Vec2ForMaskedLM),
(ConvBertConfig, ConvBertForMaskedLM),
(LEDConfig, LEDForConditionalGeneration),
Expand Down Expand Up @@ -414,6 +435,7 @@
MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
[
# Model for Causal LM mapping
(BigBirdConfig, BigBirdForCausalLM),
(CamembertConfig, CamembertForCausalLM),
(XLMRobertaConfig, XLMRobertaForCausalLM),
(RobertaConfig, RobertaForCausalLM),
Expand Down Expand Up @@ -443,6 +465,7 @@
MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
[
# Model for Masked LM mapping
(BigBirdConfig, BigBirdForMaskedLM),
(Wav2Vec2Config, Wav2Vec2ForMaskedLM),
(ConvBertConfig, ConvBertForMaskedLM),
(LayoutLMConfig, LayoutLMForMaskedLM),
Expand Down Expand Up @@ -490,6 +513,7 @@
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
[
# Model for Sequence Classification mapping
(BigBirdConfig, BigBirdForSequenceClassification),
(ConvBertConfig, ConvBertForSequenceClassification),
(LEDConfig, LEDForSequenceClassification),
(DistilBertConfig, DistilBertForSequenceClassification),
Expand Down Expand Up @@ -523,6 +547,7 @@
MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
[
# Model for Question Answering mapping
(BigBirdConfig, BigBirdForQuestionAnswering),
(ConvBertConfig, ConvBertForQuestionAnswering),
(LEDConfig, LEDForQuestionAnswering),
(DistilBertConfig, DistilBertForQuestionAnswering),
Expand Down Expand Up @@ -558,6 +583,7 @@
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
[
# Model for Token Classification mapping
(BigBirdConfig, BigBirdForTokenClassification),
(ConvBertConfig, ConvBertForTokenClassification),
(LayoutLMConfig, LayoutLMForTokenClassification),
(DistilBertConfig, DistilBertForTokenClassification),
Expand All @@ -583,6 +609,7 @@
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
[
# Model for Multiple Choice mapping
(BigBirdConfig, BigBirdForMultipleChoice),
(ConvBertConfig, ConvBertForMultipleChoice),
(CamembertConfig, CamembertForMultipleChoice),
(ElectraConfig, ElectraForMultipleChoice),
Expand Down
Loading