diff --git a/README.md b/README.md index de2917c9a23855..30a00c8c27770e 100644 --- a/README.md +++ b/README.md @@ -194,6 +194,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[BARThez](https://huggingface.co/transformers/model_doc/barthez.html)** (from École polytechnique) released with the paper [BARThez: a Skilled Pretrained French Sequence-to-Sequence Model](https://arxiv.org/abs/2010.12321) by Moussa Kamal Eddine, Antoine J.-P. Tixier, Michalis Vazirgiannis. 1. **[BERT](https://huggingface.co/transformers/model_doc/bert.html)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. 1. **[BERT For Sequence Generation](https://huggingface.co/transformers/model_doc/bertgeneration.html)** (from Google) released with the paper [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. +1. **[BigBird-RoBERTa](https://huggingface.co/transformers/model_doc/bigbird.html)** (from Google Research) released with the paper [Big Bird: Transformers for Longer Sequences](https://arxiv.org/abs/2007.14062) by Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, Amr Ahmed. 1. **[Blenderbot](https://huggingface.co/transformers/model_doc/blenderbot.html)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. 1. **[BlenderbotSmall](https://huggingface.co/transformers/model_doc/blenderbot_small.html)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. 1. **[BORT](https://huggingface.co/transformers/model_doc/bort.html)** (from Alexa) released with the paper [Optimal Subarchitecture Extraction For BERT](https://arxiv.org/abs/2010.10499) by Adrian de Wynter and Daniel J. Perry. diff --git a/docs/source/index.rst b/docs/source/index.rst index 3e0f83e942ad03..373012c99c04fc 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -97,130 +97,133 @@ and conversion utilities for the following models: 5. :doc:`BERT For Sequence Generation ` (from Google) released with the paper `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks `__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. -6. :doc:`Blenderbot ` (from Facebook) released with the paper `Recipes for building an +6. :doc:`BigBird-RoBERTa ` (from Google Research) released with the paper `Big Bird: Transformers + for Longer Sequences `__ by Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua + Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, Amr Ahmed. +7. :doc:`Blenderbot ` (from Facebook) released with the paper `Recipes for building an open-domain chatbot `__ by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. -7. 
:doc:`BlenderbotSmall ` (from Facebook) released with the paper `Recipes for building an +8. :doc:`BlenderbotSmall ` (from Facebook) released with the paper `Recipes for building an open-domain chatbot `__ by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. -8. :doc:`BORT ` (from Alexa) released with the paper `Optimal Subarchitecture Extraction For BERT +9. :doc:`BORT ` (from Alexa) released with the paper `Optimal Subarchitecture Extraction For BERT `__ by Adrian de Wynter and Daniel J. Perry. -9. :doc:`CamemBERT ` (from Inria/Facebook/Sorbonne) released with the paper `CamemBERT: a Tasty - French Language Model `__ by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz - Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot. -10. :doc:`ConvBERT ` (from YituTech) released with the paper `ConvBERT: Improving BERT with +10. :doc:`CamemBERT ` (from Inria/Facebook/Sorbonne) released with the paper `CamemBERT: a Tasty + French Language Model `__ by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz + Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot. +11. :doc:`ConvBERT ` (from YituTech) released with the paper `ConvBERT: Improving BERT with Span-based Dynamic Convolution `__ by Zihang Jiang, Weihao Yu, Daquan Zhou, Yunpeng Chen, Jiashi Feng, Shuicheng Yan. -11. :doc:`CTRL ` (from Salesforce) released with the paper `CTRL: A Conditional Transformer Language +12. :doc:`CTRL ` (from Salesforce) released with the paper `CTRL: A Conditional Transformer Language Model for Controllable Generation `__ by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. -12. :doc:`DeBERTa ` (from Microsoft) released with the paper `DeBERTa: Decoding-enhanced BERT with +13. :doc:`DeBERTa ` (from Microsoft) released with the paper `DeBERTa: Decoding-enhanced BERT with Disentangled Attention `__ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. -13. :doc:`DeBERTa-v2 ` (from Microsoft) released with the paper `DeBERTa: Decoding-enhanced BERT +14. :doc:`DeBERTa-v2 ` (from Microsoft) released with the paper `DeBERTa: Decoding-enhanced BERT with Disentangled Attention `__ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. -14. :doc:`DialoGPT ` (from Microsoft Research) released with the paper `DialoGPT: Large-Scale +15. :doc:`DialoGPT ` (from Microsoft Research) released with the paper `DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation `__ by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan. -15. :doc:`DistilBERT ` (from HuggingFace), released together with the paper `DistilBERT, a +16. :doc:`DistilBERT ` (from HuggingFace), released together with the paper `DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter `__ by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into `DistilGPT2 `__, RoBERTa into `DistilRoBERTa `__, Multilingual BERT into `DistilmBERT `__ and a German version of DistilBERT. -16. :doc:`DPR ` (from Facebook) released with the paper `Dense Passage Retrieval for Open-Domain +17. 
:doc:`DPR ` (from Facebook) released with the paper `Dense Passage Retrieval for Open-Domain Question Answering `__ by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. -17. :doc:`ELECTRA ` (from Google Research/Stanford University) released with the paper `ELECTRA: +18. :doc:`ELECTRA ` (from Google Research/Stanford University) released with the paper `ELECTRA: Pre-training text encoders as discriminators rather than generators `__ by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning. -18. :doc:`FlauBERT ` (from CNRS) released with the paper `FlauBERT: Unsupervised Language Model +19. :doc:`FlauBERT ` (from CNRS) released with the paper `FlauBERT: Unsupervised Language Model Pre-training for French `__ by Hang Le, Loïc Vial, Jibril Frej, Vincent Segonne, Maximin Coavoux, Benjamin Lecouteux, Alexandre Allauzen, Benoît Crabbé, Laurent Besacier, Didier Schwab. -19. :doc:`Funnel Transformer ` (from CMU/Google Brain) released with the paper `Funnel-Transformer: +20. :doc:`Funnel Transformer ` (from CMU/Google Brain) released with the paper `Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing `__ by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. -20. :doc:`GPT ` (from OpenAI) released with the paper `Improving Language Understanding by Generative +21. :doc:`GPT ` (from OpenAI) released with the paper `Improving Language Understanding by Generative Pre-Training `__ by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. -21. :doc:`GPT-2 ` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask +22. :doc:`GPT-2 ` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask Learners `__ by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. -22. :doc:`I-BERT ` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization +23. :doc:`I-BERT ` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization `__ by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer -23. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training +24. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training of Text and Layout for Document Image Understanding `__ by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. -24. :doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer +25. :doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -25. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document +26. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -26. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality +27. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering `__ by Hao Tan and Mohit Bansal. -27. :doc:`M2M100 ` (from Facebook) released with the paper `Beyond English-Centric Multilingual +28. 
:doc:`M2M100 ` (from Facebook) released with the paper `Beyond English-Centric Multilingual Machine Translation `__ by by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. -28. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by +29. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by Jörg Tiedemann. The `Marian Framework `__ is being developed by the Microsoft Translator Team. -29. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for +30. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for Neural Machine Translation `__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. -30. :doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible +31. :doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible Multilingual Pretraining and Finetuning `__ by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. -31. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted +32. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted Pre-training for Language Understanding `__ by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. -32. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained +33. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained text-to-text transformer `__ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. -33. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted +34. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization `__> by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. -34. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting +35. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -35. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient +36. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient Transformer `__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. -36. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT +37. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT Pretraining Approach `__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. -37. :doc:`SpeechToTextTransformer ` (from Facebook), released together with the paper +38. 
:doc:`SpeechToTextTransformer ` (from Facebook), released together with the paper `fairseq S2T: Fast Speech-to-Text Modeling with fairseq `__ by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino. -38. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP +39. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP about efficient neural networks? `__ by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. -39. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a +40. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer `__ by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. -40. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via +41. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via Pre-training `__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. -41. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: +42. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context `__ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. -42. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for +43. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations `__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. -43. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model +44. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model Pretraining `__ by Guillaume Lample and Alexis Conneau. -44. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: +45. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -45. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised +46. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised Cross-lingual Representation Learning at Scale `__ by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. -46. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive +47. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive Pretraining for Language Understanding `__ by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. -47. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised +48. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised Cross-Lingual Representation Learning For Speech Recognition `__ by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli. 
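Reviewer note (not part of the patch): as a quick illustration of the BigBird-RoBERTa entry added to the list above, the sketch below shows how the classes introduced in this diff are expected to be used, assuming the ``google/bigbird-roberta-base`` checkpoint referenced in ``BIG_BIRD_PRETRAINED_MODEL_ARCHIVE_LIST`` is published on the Hub. The class names, config fields and checkpoint name all come from this diff; the snippet itself is only a usage sketch mirroring the tips in the new ``docs/source/model_doc/bigbird.rst``::

    >>> from transformers import BigBirdTokenizer, BigBirdModel
    >>> tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
    >>> # the default attention_type="block_sparse" targets long inputs whose padded length is a
    >>> # multiple of config.block_size; for short inputs (< 1024 tokens) the docs advise original_full
    >>> model = BigBirdModel.from_pretrained("google/bigbird-roberta-base", attention_type="original_full")
    >>> inputs = tokenizer("BigBird extends BERT-like models to much longer sequences.", return_tensors="pt")
    >>> outputs = model(**inputs)
    >>> outputs.last_hidden_state.shape  # (batch_size, sequence_length, hidden_size)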
@@ -247,6 +250,8 @@ TensorFlow and/or Flax.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Bert Generation | ✅ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
+| BigBird | ✅ | ❌ | ✅ | ❌ | ❌ |
++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Blenderbot | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BlenderbotSmall | ✅ | ❌ | ✅ | ✅ | ❌ |
@@ -407,6 +412,7 @@ TensorFlow and/or Flax.
    model_doc/bert
    model_doc/bertweet
    model_doc/bertgeneration
+   model_doc/bigbird
    model_doc/blenderbot
    model_doc/blenderbot_small
    model_doc/bort
diff --git a/docs/source/model_doc/bigbird.rst b/docs/source/model_doc/bigbird.rst
new file mode 100644
index 00000000000000..8d3936a79589d7
--- /dev/null
+++ b/docs/source/model_doc/bigbird.rst
@@ -0,0 +1,128 @@
+..
+    Copyright 2021 The HuggingFace Team. All rights reserved.
+
+    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+    the License. You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+    specific language governing permissions and limitations under the License.
+
+BigBird
+-----------------------------------------------------------------------------------------------------------------------
+
+Overview
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The BigBird model was proposed in `Big Bird: Transformers for Longer Sequences `__ by
+Zaheer, Manzil and Guruganesh, Guru and Dubey, Kumar Avinava and Ainslie, Joshua and Alberti, Chris and Ontanon,
+Santiago and Pham, Philip and Ravula, Anirudh and Wang, Qifan and Yang, Li and others. BigBird is a sparse-attention
+based transformer which extends Transformer-based models, such as BERT, to much longer sequences. In addition to sparse
+attention, BigBird also applies global attention as well as random attention to the input sequence. Theoretically, it
+has been shown that applying sparse, global, and random attention approximates full attention, while being
+computationally much more efficient for longer sequences. As a consequence of the capability to handle longer context,
+BigBird has shown improved performance on various long document NLP tasks, such as question answering and
+summarization, compared to BERT or RoBERTa.
+
+The abstract from the paper is the following:
+
+*Transformers-based models, such as BERT, have been one of the most successful deep learning models for NLP.
+Unfortunately, one of their core limitations is the quadratic dependency (mainly in terms of memory) on the sequence
+length due to their full attention mechanism. To remedy this, we propose, BigBird, a sparse attention mechanism that
+reduces this quadratic dependency to linear. We show that BigBird is a universal approximator of sequence functions and
+is Turing complete, thereby preserving these properties of the quadratic, full attention model.
Along the way, our
+theoretical analysis reveals some of the benefits of having O(1) global tokens (such as CLS), that attend to the entire
+sequence as part of the sparse attention mechanism. The proposed sparse attention can handle sequences of length up to
+8x of what was previously possible using similar hardware. As a consequence of the capability to handle longer context,
+BigBird drastically improves performance on various NLP tasks such as question answering and summarization. We also
+propose novel applications to genomics data.*
+
+Tips:
+
+- BigBird comes with 2 implementations: **original_full** & **block_sparse**. For sequence lengths < 1024, using
+  **original_full** is advised as there is no benefit in using **block_sparse** attention.
+- The code currently uses a window size of 3 blocks and 2 global blocks.
+- The sequence length must be divisible by the block size.
+- The current implementation supports only **ITC**.
+- The current implementation doesn't support **num_random_blocks = 0**.
+
+The original code can be found `here `__.
+
+BigBirdConfig
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.BigBirdConfig
+    :members:
+
+
+BigBirdTokenizer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.BigBirdTokenizer
+    :members: build_inputs_with_special_tokens, get_special_tokens_mask,
+        create_token_type_ids_from_sequences, save_vocabulary
+
+
+BigBird specific outputs
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.models.big_bird.modeling_big_bird.BigBirdForPreTrainingOutput
+    :members:
+
+
+BigBirdModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.BigBirdModel
+    :members: forward
+
+
+BigBirdForPreTraining
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.BigBirdForPreTraining
+    :members: forward
+
+
+BigBirdForCausalLM
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.BigBirdForCausalLM
+    :members: forward
+
+
+BigBirdForMaskedLM
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.BigBirdForMaskedLM
+    :members: forward
+
+
+BigBirdForSequenceClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.BigBirdForSequenceClassification
+    :members: forward
+
+
+BigBirdForMultipleChoice
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.BigBirdForMultipleChoice
+    :members: forward
+
+
+BigBirdForTokenClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.BigBirdForTokenClassification
+    :members: forward
+
+
+BigBirdForQuestionAnswering
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+..
autoclass:: transformers.BigBirdForQuestionAnswering + :members: forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f08f8c4b919401..1a78c5e4989b47 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -150,6 +150,7 @@ "models.bert_generation": ["BertGenerationConfig"], "models.bert_japanese": ["BertJapaneseTokenizer", "CharacterTokenizer", "MecabTokenizer"], "models.bertweet": ["BertweetTokenizer"], + "models.big_bird": ["BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdConfig", "BigBirdTokenizer"], "models.blenderbot": ["BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BlenderbotConfig", "BlenderbotTokenizer"], "models.blenderbot_small": [ "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP", @@ -484,6 +485,22 @@ "load_tf_weights_in_bert_generation", ] ) + _import_structure["models.big_bird"].extend( + [ + "BIG_BIRD_PRETRAINED_MODEL_ARCHIVE_LIST", + "BigBirdForCausalLM", + "BigBirdForMaskedLM", + "BigBirdForMultipleChoice", + "BigBirdForPreTraining", + "BigBirdForQuestionAnswering", + "BigBirdForSequenceClassification", + "BigBirdForTokenClassification", + "BigBirdLayer", + "BigBirdModel", + "BigBirdPreTrainedModel", + "load_tf_weights_in_big_bird", + ] + ) _import_structure["models.blenderbot"].extend( [ "BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1376,6 +1393,7 @@ from .models.bert_generation import BertGenerationConfig from .models.bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer from .models.bertweet import BertweetTokenizer + from .models.big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig, BigBirdTokenizer from .models.blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig, BlenderbotTokenizer from .models.blenderbot_small import ( BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -1678,6 +1696,20 @@ BertGenerationEncoder, load_tf_weights_in_bert_generation, ) + from .models.big_bird import ( + BIG_BIRD_PRETRAINED_MODEL_ARCHIVE_LIST, + BigBirdForCausalLM, + BigBirdForMaskedLM, + BigBirdForMultipleChoice, + BigBirdForPreTraining, + BigBirdForQuestionAnswering, + BigBirdForSequenceClassification, + BigBirdForTokenClassification, + BigBirdLayer, + BigBirdModel, + BigBirdPreTrainedModel, + load_tf_weights_in_big_bird, + ) from .models.blenderbot import ( BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BlenderbotForCausalLM, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index ca371d804ca389..465612f1dff966 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -25,6 +25,7 @@ bert_generation, bert_japanese, bertweet, + big_bird, blenderbot, blenderbot_small, camembert, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index c28d3190dce2ce..27726b4d6ba1b6 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -22,6 +22,7 @@ from ..bart.configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig from ..bert.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig from ..bert_generation.configuration_bert_generation import BertGenerationConfig +from ..big_bird.configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig from ..blenderbot.configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig from ..blenderbot_small.configuration_blenderbot_small import ( 
BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -80,6 +81,7 @@ (key, value) for pretrained_map in [ # Add archive maps here + BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -127,6 +129,7 @@ CONFIG_MAPPING = OrderedDict( [ # Add configs here + ("big_bird", BigBirdConfig), ("speech_to_text", Speech2TextConfig), ("wav2vec2", Wav2Vec2Config), ("m2m_100", M2M100Config), @@ -180,6 +183,7 @@ MODEL_NAMES_MAPPING = OrderedDict( [ # Add full (and cased) model names here + ("big_bird", "BigBird"), ("speech_to_text", "Speech2Text"), ("wav2vec2", "Wav2Vec2"), ("m2m_100", "M2M100"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 4d11dbaa37b65f..be57b7ea22075f 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -51,6 +51,16 @@ BertModel, ) from ..bert_generation.modeling_bert_generation import BertGenerationDecoder, BertGenerationEncoder +from ..big_bird.modeling_big_bird import ( + BigBirdForCausalLM, + BigBirdForMaskedLM, + BigBirdForMultipleChoice, + BigBirdForPreTraining, + BigBirdForQuestionAnswering, + BigBirdForSequenceClassification, + BigBirdForTokenClassification, + BigBirdModel, +) from ..blenderbot.modeling_blenderbot import BlenderbotForCausalLM, BlenderbotForConditionalGeneration, BlenderbotModel from ..blenderbot_small.modeling_blenderbot_small import ( BlenderbotSmallForCausalLM, @@ -263,6 +273,7 @@ BartConfig, BertConfig, BertGenerationConfig, + BigBirdConfig, BlenderbotConfig, BlenderbotSmallConfig, CamembertConfig, @@ -315,6 +326,7 @@ MODEL_MAPPING = OrderedDict( [ # Base model mapping + (BigBirdConfig, BigBirdModel), (Speech2TextConfig, Speech2TextModel), (Wav2Vec2Config, Wav2Vec2Model), (M2M100Config, M2M100Model), @@ -380,6 +392,7 @@ (RobertaConfig, RobertaForMaskedLM), (SqueezeBertConfig, SqueezeBertForMaskedLM), (BertConfig, BertForPreTraining), + (BigBirdConfig, BigBirdForPreTraining), (OpenAIGPTConfig, OpenAIGPTLMHeadModel), (GPT2Config, GPT2LMHeadModel), (MobileBertConfig, MobileBertForPreTraining), @@ -402,6 +415,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( [ # Model with LM heads mapping + (BigBirdConfig, BigBirdForMaskedLM), (Speech2TextConfig, Speech2TextForConditionalGeneration), (Wav2Vec2Config, Wav2Vec2ForMaskedLM), (M2M100Config, M2M100ForConditionalGeneration), @@ -444,6 +458,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict( [ # Model for Causal LM mapping + (BigBirdConfig, BigBirdForCausalLM), (CamembertConfig, CamembertForCausalLM), (XLMRobertaConfig, XLMRobertaForCausalLM), (RobertaConfig, RobertaForCausalLM), @@ -473,6 +488,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict( [ # Model for Masked LM mapping + (BigBirdConfig, BigBirdForMaskedLM), (Wav2Vec2Config, Wav2Vec2ForMaskedLM), (ConvBertConfig, ConvBertForMaskedLM), (LayoutLMConfig, LayoutLMForMaskedLM), @@ -523,6 +539,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( [ # Model for Sequence Classification mapping + (BigBirdConfig, BigBirdForSequenceClassification), (ConvBertConfig, ConvBertForSequenceClassification), (LEDConfig, LEDForSequenceClassification), (DistilBertConfig, DistilBertForSequenceClassification), @@ -558,6 +575,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict( [ # Model for Question Answering mapping + (BigBirdConfig, BigBirdForQuestionAnswering), (ConvBertConfig, ConvBertForQuestionAnswering), (LEDConfig, 
LEDForQuestionAnswering), (DistilBertConfig, DistilBertForQuestionAnswering), @@ -595,6 +613,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict( [ # Model for Token Classification mapping + (BigBirdConfig, BigBirdForTokenClassification), (ConvBertConfig, ConvBertForTokenClassification), (LayoutLMConfig, LayoutLMForTokenClassification), (DistilBertConfig, DistilBertForTokenClassification), @@ -622,6 +641,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict( [ # Model for Multiple Choice mapping + (BigBirdConfig, BigBirdForMultipleChoice), (ConvBertConfig, ConvBertForMultipleChoice), (CamembertConfig, CamembertForMultipleChoice), (ElectraConfig, ElectraForMultipleChoice), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 346e626459199f..4466fce871fd9b 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -60,6 +60,7 @@ BartConfig, BertConfig, BertGenerationConfig, + BigBirdConfig, BlenderbotConfig, BlenderbotSmallConfig, CamembertConfig, @@ -111,6 +112,7 @@ from ..albert.tokenization_albert import AlbertTokenizer from ..barthez.tokenization_barthez import BarthezTokenizer from ..bert_generation.tokenization_bert_generation import BertGenerationTokenizer + from ..big_bird.tokenization_big_bird import BigBirdTokenizer from ..camembert.tokenization_camembert import CamembertTokenizer from ..deberta_v2.tokenization_deberta_v2 import DebertaV2Tokenizer from ..m2m_100 import M2M100Tokenizer @@ -129,6 +131,7 @@ AlbertTokenizer = None BarthezTokenizer = None BertGenerationTokenizer = None + BigBirdTokenizer = None CamembertTokenizer = None DebertaV2Tokenizer = None MarianTokenizer = None @@ -258,6 +261,7 @@ (TapasConfig, (TapasTokenizer, None)), (LEDConfig, (LEDTokenizer, LEDTokenizerFast)), (ConvBertConfig, (ConvBertTokenizer, ConvBertTokenizerFast)), + (BigBirdConfig, (BigBirdTokenizer, None)), (IBertConfig, (RobertaTokenizer, RobertaTokenizerFast)), (Wav2Vec2Config, (Wav2Vec2CTCTokenizer, None)), ] diff --git a/src/transformers/models/big_bird/__init__.py b/src/transformers/models/big_bird/__init__.py new file mode 100644 index 00000000000000..21aa3e927f8e87 --- /dev/null +++ b/src/transformers/models/big_bird/__init__.py @@ -0,0 +1,82 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...file_utils import _BaseLazyModule, is_torch_available + + +_import_structure = { + "configuration_big_bird": ["BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdConfig"], + "tokenization_big_bird": ["BigBirdTokenizer"], +} + +if is_torch_available(): + _import_structure["modeling_big_bird"] = [ + "BIG_BIRD_PRETRAINED_MODEL_ARCHIVE_LIST", + "BigBirdForCausalLM", + "BigBirdForMaskedLM", + "BigBirdForMultipleChoice", + "BigBirdForPreTraining", + "BigBirdForQuestionAnswering", + "BigBirdForSequenceClassification", + "BigBirdForTokenClassification", + "BigBirdLayer", + "BigBirdModel", + "BigBirdPreTrainedModel", + "load_tf_weights_in_big_bird", + ] + + +if TYPE_CHECKING: + from .configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig + from .tokenization_big_bird import BigBirdTokenizer + + if is_torch_available(): + from .modeling_big_bird import ( + BIG_BIRD_PRETRAINED_MODEL_ARCHIVE_LIST, + BigBirdForCausalLM, + BigBirdForMaskedLM, + BigBirdForMultipleChoice, + BigBirdForPreTraining, + BigBirdForQuestionAnswering, + BigBirdForSequenceClassification, + BigBirdForTokenClassification, + BigBirdLayer, + BigBirdModel, + BigBirdPreTrainedModel, + load_tf_weights_in_big_bird, + ) + + +else: + import importlib + import os + import sys + + class _LazyModule(_BaseLazyModule): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. + """ + + __file__ = globals()["__file__"] + __path__ = [os.path.dirname(__file__)] + + def _get_module(self, module_name: str): + return importlib.import_module("." + module_name, self.__name__) + + sys.modules[__name__] = _LazyModule(__name__, _import_structure) diff --git a/src/transformers/models/big_bird/configuration_big_bird.py b/src/transformers/models/big_bird/configuration_big_bird.py new file mode 100644 index 00000000000000..6ac9c4b951066e --- /dev/null +++ b/src/transformers/models/big_bird/configuration_big_bird.py @@ -0,0 +1,159 @@ +# coding=utf-8 +# Copyright 2021 Google Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" BigBird model configuration """ + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "google/bigbird-roberta-base": "https://huggingface.co/google/bigbird-roberta-base/resolve/main/config.json", + "google/bigbird-roberta-large": "https://huggingface.co/google/bigbird-roberta-large/resolve/main/config.json", + "google/bigbird-base-trivia-itc": "https://huggingface.co/google/bigbird-base-trivia-itc/resolve/main/config.json", + # See all BigBird models at https://huggingface.co/models?filter=big_bird +} + + +class BigBirdConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.BigBirdModel`. 
It is used to
+    instantiate a BigBird model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the BigBird
+    `google/bigbird-roberta-base `__ architecture.
+
+    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
+    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
+
+
+    Args:
+        vocab_size (:obj:`int`, `optional`, defaults to 50358):
+            Vocabulary size of the BigBird model. Defines the number of different tokens that can be represented by the
+            :obj:`inputs_ids` passed when calling :class:`~transformers.BigBirdModel`.
+        hidden_size (:obj:`int`, `optional`, defaults to 768):
+            Dimension of the encoder layers and the pooler layer.
+        num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (:obj:`int`, `optional`, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (:obj:`int`, `optional`, defaults to 3072):
+            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu_fast"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string,
+            :obj:`"gelu"`, :obj:`"gelu_fast"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
+        hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        max_position_embeddings (:obj:`int`, `optional`, defaults to 4096):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 1024 or 2048 or 4096).
+        type_vocab_size (:obj:`int`, `optional`, defaults to 2):
+            The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.BigBirdModel`.
+        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if ``config.is_decoder=True``.
+        attention_type (:obj:`str`, `optional`, defaults to :obj:`"block_sparse"`):
+            Whether to use block sparse attention (with O(n) complexity) as introduced in the paper, or the original
+            attention layer (with O(n^2) complexity). Possible values are :obj:`"original_full"` and
+            :obj:`"block_sparse"`.
+        use_bias (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether to use a bias in the query, key and value projections.
+        rescale_embeddings (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether to rescale the embeddings by (hidden_size ** 0.5).
+        block_size (:obj:`int`, `optional`, defaults to 64):
+            Size of each block. Useful only when :obj:`attention_type == "block_sparse"`.
+        num_random_blocks (:obj:`int`, `optional`, defaults to 3):
+            Each query is going to attend to this many random blocks.
Useful only when :obj:`attention_type == + "block_sparse"`. + gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + + Example:: + + >>> from transformers import BigBirdModel, BigBirdConfig + + >>> # Initializing a BigBird google/bigbird-roberta-base style configuration + >>> configuration = BigBirdConfig() + + >>> # Initializing a model from the google/bigbird-roberta-base style configuration + >>> model = BigBirdModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = "big_bird" + + def __init__( + self, + vocab_size=50358, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu_fast", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=4096, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + use_cache=True, + is_encoder_decoder=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + sep_token_id=66, + attention_type="block_sparse", + use_bias=True, + rescale_embeddings=False, + block_size=64, + num_random_blocks=3, + gradient_checkpointing=False, + **kwargs + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + sep_token_id=sep_token_id, + **kwargs, + ) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.type_vocab_size = type_vocab_size + self.layer_norm_eps = layer_norm_eps + self.use_cache = use_cache + self.is_encoder_decoder = is_encoder_decoder + self.gradient_checkpointing = gradient_checkpointing + + self.rescale_embeddings = rescale_embeddings + self.attention_type = attention_type + self.use_bias = use_bias + self.block_size = block_size + self.num_random_blocks = num_random_blocks diff --git a/src/transformers/models/big_bird/convert_bigbird_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/big_bird/convert_bigbird_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 00000000000000..7cea701acd8f71 --- /dev/null +++ b/src/transformers/models/big_bird/convert_bigbird_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,69 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Convert BigBird checkpoint.""" + + +import argparse + +from transformers import BigBirdConfig, BigBirdForPreTraining, BigBirdForQuestionAnswering, load_tf_weights_in_big_bird +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa): + # Initialise PyTorch model + config = BigBirdConfig.from_json_file(big_bird_config_file) + print("Building PyTorch model from configuration: {}".format(str(config))) + + if is_trivia_qa: + model = BigBirdForQuestionAnswering(config) + else: + model = BigBirdForPreTraining(config) + + # Load weights from tf checkpoint + load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--big_bird_config_file", + default=None, + type=str, + required=True, + help="The config json file corresponding to the pre-trained BERT model. \n" + "This specifies the model architecture.", + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--is_trivia_qa", action="store_true", help="Whether to convert a model with a trivia_qa head." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch( + args.tf_checkpoint_path, args.big_bird_config_file, args.pytorch_dump_path, args.is_trivia_qa + ) diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py new file mode 100755 index 00000000000000..63b61e19480b76 --- /dev/null +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -0,0 +1,2976 @@ +# coding=utf-8 +# Copyright 2021 Google Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch BigBird model. 
""" + + +import math +import os +from dataclasses import dataclass +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...file_utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel, SequenceSummary, apply_chunking_to_forward +from ...utils import logging +from .configuration_big_bird import BigBirdConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google/bigbird-roberta-base" +_CONFIG_FOR_DOC = "BigBirdConfig" +_TOKENIZER_FOR_DOC = "BigBirdTokenizer" + +BIG_BIRD_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "google/bigbird-roberta-base", + "google/bigbird-roberta-large", + "google/bigbird-base-trivia-itc", + # See all BigBird models at https://huggingface.co/models?filter=big_bird +] + +_TRIVIA_QA_MAPPING = { + "big_bird_attention": "attention/self", + "output_layer_norm": "output/LayerNorm", + "attention_output": "attention/output/dense", + "output": "output/dense", + "self_attention_layer_norm": "attention/output/LayerNorm", + "intermediate": "intermediate/dense", + "word_embeddings": "bert/embeddings/word_embeddings", + "position_embedding": "bert/embeddings/position_embeddings", + "type_embeddings": "bert/embeddings/token_type_embeddings", + "embeddings": "bert/embeddings", + "layer_normalization": "output/LayerNorm", + "layer_norm": "LayerNorm", + "trivia_qa_head": "qa_classifier", + "dense": "intermediate/dense", + "dense_1": "qa_outputs", +} + + +def load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=False): + """Load tf checkpoints in a pytorch model.""" + + def load_tf_weights_bert(init_vars, tf_path): + names = [] + tf_weights = {} + + for name, shape in init_vars: + array = tf.train.load_variable(tf_path, name) + name = name.replace("bert/encoder/LayerNorm", "bert/embeddings/LayerNorm") + logger.info(f"Loading TF weight {name} with shape {shape}") + names.append(name) + tf_weights[name] = array + + return names, tf_weights + + def load_tf_weights_trivia_qa(init_vars): + names = [] + tf_weights = {} + + for i, var in enumerate(init_vars): + name_items = var.name.split("/") + + if "transformer_scaffold" in name_items[0]: + layer_name_items = name_items[0].split("_") + if len(layer_name_items) < 3: + layer_name_items += [0] + + name_items[0] = f"bert/encoder/layer_{layer_name_items[2]}" + + name = "/".join([_TRIVIA_QA_MAPPING[x] if x in _TRIVIA_QA_MAPPING else x for x in name_items])[ + :-2 + ] # remove last :0 in variable + + if "self/attention/output" in name: + name = name.replace("self/attention/output", "output") + + if i >= len(init_vars) - 2: + name = name.replace("intermediate", "output") + + logger.info("Loading TF weight {} with shape {}".format(name, var.shape)) + array = var.value().numpy() + names.append(name) + tf_weights[name] = array + + return names, tf_weights + + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in 
PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) + + # Load weights from TF model + init_vars = tf.saved_model.load(tf_path).variables if is_trivia_qa else tf.train.list_variables(tf_path) + + assert len(init_vars) > 0, "Loaded trained variables cannot be empty." + + pt_names = list(model.state_dict().keys()) + + if is_trivia_qa: + names, tf_weights = load_tf_weights_trivia_qa(init_vars) + else: + names, tf_weights = load_tf_weights_bert(init_vars, tf_path) + + for txt_name in names: + array = tf_weights[txt_name] + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + pt_name = [] + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + pt_name.append("weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + pt_name.append("bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + pt_name.append("weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + pt_name.append("classifier") + elif scope_names[0] == "transform": + pointer = getattr(pointer, "transform") + pt_name.append("transform") + if ("bias" in name) or ("kernel" in name): + pointer = getattr(pointer, "dense") + pt_name.append("dense") + elif ("beta" in name) or ("gamma" in name): + pointer = getattr(pointer, "LayerNorm") + pt_name.append("LayerNorm") + else: + try: + pointer = getattr(pointer, scope_names[0]) + pt_name.append(f"{scope_names[0]}") + except AttributeError: + logger.info(f"Skipping {m_name}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + pt_name.append(f"{num}") + if m_name[-11:] == "_embeddings" or m_name == "embeddings": + pointer = getattr(pointer, "weight") + pt_name.append("weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if len(array.shape) > len(pointer.shape) and math.prod(array.shape) == math.prod(pointer.shape): + # print(txt_name, array.shape) + if ( + txt_name.endswith("attention/self/key/kernel") + or txt_name.endswith("attention/self/query/kernel") + or txt_name.endswith("attention/self/value/kernel") + ): + array = array.transpose(1, 0, 2).reshape(pointer.shape) + elif txt_name.endswith("attention/output/dense/kernel"): + array = array.transpose(0, 2, 1).reshape(pointer.shape) + else: + array = array.reshape(pointer.shape) + + if pointer.shape != array.shape: + raise ValueError( + f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched of {txt_name}." 
+ ) + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + pt_weight_name = ".".join(pt_name) + logger.info(f"Initialize PyTorch weight {pt_weight_name} from {txt_name}.") + pointer.data = torch.from_numpy(array) + tf_weights.pop(txt_name, None) + pt_names.remove(pt_weight_name) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + logger.info(f"Weights not initialized in PyTorch model: {', '.join(pt_names)}.") + return model + + +class BigBirdEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + # End copy + + self.rescale_embeddings = config.rescale_embeddings + self.hidden_size = config.hidden_size + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.rescale_embeddings: + inputs_embeds = inputs_embeds * (self.hidden_size ** 0.5) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + embeddings = self.dropout(embeddings) + embeddings = self.LayerNorm(embeddings) + return embeddings + + +class BigBirdSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, 
bias=config.use_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BigBirdModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = F.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
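+        # (nn.Dropout zeroes individual attention probabilities and rescales the remaining
+        # ones by 1 / (1 - p) during training, so some key positions are effectively hidden
+        # from a given query on each forward pass)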
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class BigBirdBlockSparseAttention(nn.Module): + def __init__(self, config, seed=None): + super().__init__() + + self.max_seqlen = config.max_position_embeddings + self.seed = seed + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.num_random_blocks = config.num_random_blocks + self.block_size = config.block_size + + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + band_mask=None, + from_mask=None, + to_mask=None, + from_blocked_mask=None, + to_blocked_mask=None, + output_attentions=None, + ): + # Currently this `class` can't be used in decoder. 
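+        # Mask shapes expected below (inferred from the slicing / broadcasting in
+        # `bigbird_block_sparse_attention`):
+        #   from_mask:                          [batch_size, 1, seqlen, 1]
+        #   to_mask:                            [batch_size, 1, 1, seqlen]
+        #   from_blocked_mask / to_blocked_mask: [batch_size, seqlen // block_size, block_size]
+        #   band_mask:                          [batch_size, 1, seqlen // block_size - 4, block_size, 3 * block_size]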
+ + batch_size, seqlen, _ = hidden_states.size() + to_seq_length = from_seq_length = seqlen + from_block_size = to_block_size = self.block_size + + assert from_seq_length % from_block_size == 0, "Query sided sequence length must be multiple of block size" + assert to_seq_length % to_block_size == 0, "Key/Value sided sequence length must be multiple of block size" + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + context_layer, attention_probs = self.bigbird_block_sparse_attention( + query_layer, + key_layer, + value_layer, + band_mask, + from_mask, + to_mask, + from_blocked_mask, + to_blocked_mask, + self.num_attention_heads, + self.num_random_blocks, + self.attention_head_size, + from_block_size, + to_block_size, + batch_size, + from_seq_length, + to_seq_length, + seed=self.seed, + plan_from_length=None, + plan_num_rand_blocks=None, + output_attentions=output_attentions, + ) + + context_layer = context_layer.contiguous().view(batch_size, from_seq_length, -1) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + return outputs + + @staticmethod + def torch_bmm_nd(inp_1, inp_2, ndim=None): + """ Fast nd matrix multiplication """ + # faster replacement of torch.einsum ("bhqk,bhkd->bhqd") + return torch.bmm(inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:])).view( + inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 1]) + ) + + @staticmethod + def torch_bmm_nd_transpose(inp_1, inp_2, ndim=None): + """ Fast nd matrix multiplication with transpose """ + # faster replacement of torch.einsum (bhqd,bhkd->bhqk) + return torch.bmm( + inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:]).transpose(1, 2) + ).view(inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 2])) + + def bigbird_block_sparse_attention( + self, + query_layer, + key_layer, + value_layer, + band_mask, + from_mask, + to_mask, + from_blocked_mask, + to_blocked_mask, + n_heads, + n_rand_blocks, + attention_head_size, + from_block_size, + to_block_size, + batch_size, + from_seq_len, + to_seq_len, + seed, + plan_from_length, + plan_num_rand_blocks, + output_attentions, + ): + + # BigBird block-sparse attention as suggested in paper + + # ITC: + # global tokens: 2 x block_size + # window tokens: 3 x block_size + # random tokens: num_rand_tokens x block_size + + # ETC: + # global tokens: extra_globals_tokens + 2 x block_size + # window tokens: 3 x block_size + # random tokens: num_rand_tokens x block_size + + # Note: + # 1) Currently, ETC is not supported. + # 2) Window size is fixed to 3 blocks & it can be changed only by + # changing `block_size`. + # 3) Number of global blocks are fixed (2 blocks here) & global tokens can be + # controlled only by `block_size`. + + # attention is calculated separately for q[0], q[1], q[2:-2], q[-2], q[-1] in order to use special trick of shifting tokens (for calculating sliding attention) + # hence following code can be divided into 5 parts. 
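+        # Per-row key budget (in blocks):
+        #   q[0], q[-1]      -> all key blocks (fully global rows)
+        #   q[1]             -> blocks {0, 1, 2, last} + n_rand_blocks random blocks
+        #   q[-2]            -> blocks {0, last-2, last-1, last} + n_rand_blocks random blocks
+        #   q[2:-2] (middle) -> 2 global + 3 sliding + n_rand_blocks random blocks each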
+ + if from_seq_len // from_block_size != to_seq_len // to_block_size: + raise ValueError("Error the number of blocks needs to be same!") + + rsqrt_d = 1 / math.sqrt(attention_head_size) + bsz = batch_size + + # generate random attention and corresponding masks + np.random.seed(seed) + if from_seq_len in [1024, 3072, 4096]: # old plans used in paper + rand_attn = [ + self._bigbird_block_rand_mask( + self.max_seqlen, self.max_seqlen, from_block_size, to_block_size, n_rand_blocks, last_idx=1024 + )[: (from_seq_len // from_block_size - 2)] + for _ in range(n_heads) + ] + else: + if plan_from_length is None: + plan_from_length, plan_num_rand_blocks = self._get_rand_attn_plan( + from_seq_len, from_block_size, n_rand_blocks + ) + + rand_attn = self._bigbird_block_rand_mask_with_head( + from_seq_length=from_seq_len, + to_seq_length=to_seq_len, + from_block_size=from_block_size, + to_block_size=to_block_size, + num_heads=n_heads, + plan_from_length=plan_from_length, + plan_num_rand_blocks=plan_num_rand_blocks, + ) + + rand_attn = np.stack(rand_attn, axis=0) + rand_attn = torch.tensor(rand_attn, device=query_layer.device, dtype=torch.long) + rand_attn.unsqueeze_(0) + rand_attn = torch.cat([rand_attn for _ in range(batch_size)], dim=0) + + rand_mask = self._create_rand_mask_from_inputs( + from_blocked_mask, to_blocked_mask, rand_attn, n_heads, n_rand_blocks, bsz, from_seq_len, from_block_size + ) + + blocked_query_matrix = query_layer.view(bsz, n_heads, from_seq_len // from_block_size, from_block_size, -1) + blocked_key_matrix = key_layer.view(bsz, n_heads, to_seq_len // to_block_size, to_block_size, -1) + blocked_value_matrix = value_layer.view(bsz, n_heads, to_seq_len // to_block_size, to_block_size, -1) + + # preparing block for randn attn + gathered_key = self.torch_gather_b2(blocked_key_matrix, rand_attn) + gathered_key = gathered_key.view( + bsz, n_heads, to_seq_len // to_block_size - 2, n_rand_blocks * to_block_size, -1 + ) # [bsz, n_heads, to_seq_len//to_block_size-2, n_rand_blocks, to_block_size, -1] + gathered_value = self.torch_gather_b2(blocked_value_matrix, rand_attn) + gathered_value = gathered_value.view( + bsz, n_heads, to_seq_len // to_block_size - 2, n_rand_blocks * to_block_size, -1 + ) # [bsz, n_heads, to_seq_len//to_block_size-2, n_rand_blocks, to_block_size, -1] + + # 1st PART + # 1st block (global block) attention scores + # q[0] x (k[0], k[1], k[2], k[3], k[4] .... 
) + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, to_seq_len] + first_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, 0], key_layer, ndim=4) + + first_product = first_product * rsqrt_d + first_product += (1.0 - to_mask) * -10000.0 + first_attn_weights = F.softmax(first_product, dim=-1) # [bsz, n_heads, from_block_size, to_seq_len] + + # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] + first_context_layer = self.torch_bmm_nd(first_attn_weights, value_layer, ndim=4) + first_context_layer.unsqueeze_(2) + + # 2nd PART + # 2nd block attention scores + # q[1] x (sliding_keys, random_keys, global_keys) + # sliding key blocks -> 2nd, 3rd blocks + # global key blocks -> 1st block + + second_key_mat = torch.cat( + [ + blocked_key_matrix[:, :, 0], + blocked_key_matrix[:, :, 1], + blocked_key_matrix[:, :, 2], + blocked_key_matrix[:, :, -1], + gathered_key[:, :, 0], + ], + dim=2, + ) # [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] + second_value_mat = torch.cat( + [ + blocked_value_matrix[:, :, 0], + blocked_value_matrix[:, :, 1], + blocked_value_matrix[:, :, 2], + blocked_value_matrix[:, :, -1], + gathered_value[:, :, 0], + ], + dim=2, + ) # [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + second_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, 1], second_key_mat, ndim=4) + second_seq_pad = torch.cat( + [ + to_mask[:, :, :, : 3 * to_block_size], + to_mask[:, :, :, -to_block_size:], + first_context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]), + ], + dim=3, + ) + second_rand_pad = torch.cat( + [ + first_context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]), + rand_mask[:, :, 0], + ], + dim=3, + ) + second_product = second_product * rsqrt_d + second_product += (1.0 - torch.minimum(second_seq_pad, second_rand_pad)) * -10000.0 + second_attn_weights = F.softmax( + second_product, dim=-1 + ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + + # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, -1] + second_context_layer = self.torch_bmm_nd(second_attn_weights, second_value_mat, ndim=4) + + second_context_layer.unsqueeze_(2) + + # 3rd PART + # Middle blocks attention scores + # q[-2:2] x (sliding_keys, random_keys, global_keys) + # sliding attn is calculated using special trick of shifting tokens as discussed in paper + # random keys are generated by taking random indices as per `rand_attn` + # global keys -> 1st & last block + + exp_blocked_key_matrix = torch.cat( + [blocked_key_matrix[:, :, 1:-3], blocked_key_matrix[:, :, 2:-2], blocked_key_matrix[:, :, 3:-1]], dim=3 + ) # [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + exp_blocked_value_matrix = torch.cat( + [blocked_value_matrix[:, :, 1:-3], blocked_value_matrix[:, :, 2:-2], blocked_value_matrix[:, :, 3:-1]], + dim=3, + ) # [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + middle_query_matrix = blocked_query_matrix[:, :, 2:-2] + + # sliding attention scores for q[-2:2] + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [b, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + 
inner_band_product = self.torch_bmm_nd_transpose(middle_query_matrix, exp_blocked_key_matrix, ndim=5) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, 3*to_block_size] + inner_band_product = inner_band_product * rsqrt_d + + # randn attention scores for q[-2:2] + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, from_seq_len//from_block_size-4, n_rand_blocks*to_block_size, -1] + rand_band_product = self.torch_bmm_nd_transpose(middle_query_matrix, gathered_key[:, :, 1:-1], ndim=5) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, n_rand_blocks*to_block_size] + rand_band_product = rand_band_product * rsqrt_d + + # Including 1st block (since it's global) + first_band_product = torch.einsum( + "bhlqd,bhkd->bhlqk", middle_query_matrix, blocked_key_matrix[:, :, 0] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] + first_band_product = first_band_product * rsqrt_d + + # Including last block (since it's global) + last_band_product = torch.einsum( + "bhlqd,bhkd->bhlqk", middle_query_matrix, blocked_key_matrix[:, :, -1] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] + last_band_product = last_band_product * rsqrt_d + + # masking padded tokens + inner_band_product += (1.0 - band_mask) * -10000.0 + first_band_product += (1.0 - to_mask[:, :, :, :to_block_size].unsqueeze(3)) * -10000.0 + last_band_product += (1.0 - to_mask[:, :, :, -to_block_size:].unsqueeze(3)) * -10000.0 + rand_band_product += (1.0 - rand_mask[:, :, 1:-1]) * -10000.0 + + # completing attention scores matrix for all q[-2:2] + band_product = torch.cat( + [first_band_product, inner_band_product, rand_band_product, last_band_product], dim=-1 + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] + + # safely doing softmax since attention matrix is completed + attn_weights = F.softmax( + band_product, dim=-1 + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] + + # contibution of sliding keys + # [bsz, n_heads, m//from_block_size-4, from_block_size, 3*to_block_size] x [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + context_layer = self.torch_bmm_nd( + attn_weights[:, :, :, :, to_block_size : 4 * to_block_size], exp_blocked_value_matrix, ndim=5 + ) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + + # adding contribution of random keys + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, n_rand_blocks*to_block_size] x [bsz, n_heads, from_seq_len//from_block_size-4, n_rand_blocks*to_block_size, -1] + context_layer += self.torch_bmm_nd( + attn_weights[:, :, :, :, 4 * to_block_size : -to_block_size], gathered_value[:, :, 1:-1], ndim=5 + ) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + + # adding contribution of global keys + context_layer += torch.einsum( + "bhlqk,bhkd->bhlqd", attn_weights[:, :, :, :, :to_block_size], blocked_value_matrix[:, :, 0] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + context_layer += torch.einsum( + 
"bhlqk,bhkd->bhlqd", attn_weights[:, :, :, :, -to_block_size:], blocked_value_matrix[:, :, -1] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + + # 4th PART + # last 2nd token attention scores + # q[-2] x (sliding_keys, random_keys, global_keys) + # sliding key blocks -> last 3 blocks + # global key block -> 1st block + # random key block -> based on indices stored in `randn_attn` + + second_last_key_mat = torch.cat( + [ + blocked_key_matrix[:, :, 0], + blocked_key_matrix[:, :, -3], + blocked_key_matrix[:, :, -2], + blocked_key_matrix[:, :, -1], + gathered_key[:, :, -1], + ], + dim=2, + ) # [bsz, n_heads, (4+n_random_blocks)*to_block_size, -1] + second_last_value_mat = torch.cat( + [ + blocked_value_matrix[:, :, 0], + blocked_value_matrix[:, :, -3], + blocked_value_matrix[:, :, -2], + blocked_value_matrix[:, :, -1], + gathered_value[:, :, -1], + ], + dim=2, + ) # [bsz, n_heads, (4+r)*to_block_size, -1] + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + second_last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -2], second_last_key_mat, ndim=4) + second_last_seq_pad = torch.cat( + [ + to_mask[:, :, :, :to_block_size], + to_mask[:, :, :, -3 * to_block_size :], + context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]), + ], + dim=3, + ) + second_last_rand_pad = torch.cat( + [ + context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]), + rand_mask[:, :, -1], + ], + dim=3, + ) + second_last_product = second_last_product * rsqrt_d + second_last_product += (1.0 - torch.minimum(second_last_seq_pad, second_last_rand_pad)) * -10000.0 + second_last_attn_weights = F.softmax( + second_last_product, dim=-1 + ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + + # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, -1] + second_last_context_layer = self.torch_bmm_nd(second_last_attn_weights, second_last_value_mat, ndim=4) + second_last_context_layer.unsqueeze_(2) + + # 5th PART + # last block (global) attention scores + # q[-1] x (k[0], k[1], k[2], k[3], .... 
) + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, to_seq_len] + last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -1], key_layer, ndim=4) + last_product = last_product * rsqrt_d + last_product += (1.0 - to_mask) * -10000.0 + last_attn_weights = F.softmax(last_product, dim=-1) # [bsz, n_heads, from_block_size, n] + + # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] + last_context_layer = self.torch_bmm_nd(last_attn_weights, value_layer, ndim=4) + last_context_layer.unsqueeze_(2) + + # combining representations of all tokens + context_layer = torch.cat( + [first_context_layer, second_context_layer, context_layer, second_last_context_layer, last_context_layer], + dim=2, + ) + context_layer = context_layer.view((bsz, n_heads, from_seq_len, -1)) * from_mask + context_layer = torch.transpose(context_layer, 1, 2) + + # this is just for visualizing; forward pass doesn't depend on following code + if output_attentions: + # TODO(PVP): need to verify if below code is correct + attention_probs = torch.zeros( + bsz, n_heads, from_seq_len, to_seq_len, dtype=torch.float, device=context_layer.device + ) + + # 1st query block + # corresponding to `first_context_layer` + attention_probs[:, :, :from_block_size, :] = first_attn_weights # all keys global + + # 2nd query block + # corresponding to `second_context_layer` + attention_probs[:, :, from_block_size : 2 * from_block_size, : 3 * to_block_size] = second_attn_weights[ + :, :, :, : 3 * to_block_size + ] # 1st three key blocks (global + sliding) + attention_probs[:, :, from_block_size : 2 * from_block_size, -to_block_size:] = second_attn_weights[ + :, :, :, 3 * to_block_size : 4 * to_block_size + ] # last key block (global) + # random keys + for p1, i1, w1 in zip(range(bsz), rand_attn, second_attn_weights): + # p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch + for p2, i2, w2 in zip(range(n_heads), i1, w1): + # p2, i2, w2 corresponds to head_dim i.e. 
following operation is done for each heads + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + ) + right_slice = w2[:, 4 * to_block_size :] + attn_probs_view[p1, p2, 1, :, i2[0]] = right_slice.view( + from_block_size, n_rand_blocks, to_block_size + ) + + # Middle query blocks + # corresponding to `context_layer` + # sliding keys + for q_idx in range(from_seq_len // from_block_size - 4): + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + )[:, :, 2:-2, :, 1:-1, :] + right_slice = attn_weights[:, :, q_idx, :, to_block_size : 4 * to_block_size] + attn_probs_view[:, :, q_idx, :, q_idx : q_idx + 3, :] = right_slice.view( + bsz, n_heads, from_block_size, 3, to_block_size + ) # inner_band_product + # global keys (correspomding to 1st key block) + attention_probs[:, :, 2 * from_block_size : -2 * from_block_size, :to_block_size] = attn_weights[ + :, :, :, :, :to_block_size + ].view( + bsz, n_heads, -1, to_block_size + ) # first_band_product + # global keys (corresponding to last key block) + attention_probs[:, :, 2 * from_block_size : -2 * from_block_size, -to_block_size:] = attn_weights[ + :, :, :, :, -to_block_size: + ].view( + bsz, n_heads, -1, to_block_size + ) # last_band_product + # random keys + for p1, i1, w1 in zip(range(bsz), rand_attn, attn_weights): + # p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch + for p2, i2, w2 in zip(range(n_heads), i1, w1): + # p2, i2, w2 corresponds to head_dim i.e. following operation is done for each heads + for q_idx in range(1, len(i2) - 1): + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + ) + right_slice = w2[q_idx - 1, :, 4 * to_block_size : -to_block_size] + attn_probs_view[p1, p2, q_idx + 1, :, i2[q_idx]] = right_slice.view( + from_block_size, n_rand_blocks, to_block_size + ) + + # Second-last query block + # corresponding to `second_last_context_layer` + attention_probs[:, :, -2 * from_block_size : -from_block_size, :to_block_size] = second_last_attn_weights[ + :, :, :, :to_block_size + ] # 1st key block (global) + attention_probs[ + :, :, -2 * from_block_size : -from_block_size, -3 * to_block_size : + ] = second_last_attn_weights[ + :, :, :, to_block_size : 4 * to_block_size + ] # last three blocks (global + sliding) + # random keys + for p1, i1, w1 in zip(range(bsz), rand_attn, second_last_attn_weights): + # p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch + for p2, i2, w2 in zip(range(n_heads), i1, w1): + # p2, i2, w2 corresponds to head_dim i.e. 
following operation is done for each heads + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + ) + right_slice = w2[:, 4 * to_block_size :] + attn_probs_view[p1, p2, -2, :, i2[-1]] = right_slice.view( + from_block_size, n_rand_blocks, to_block_size + ) + + # last query block + # corresponding to `last_context_layer` + attention_probs[:, :, -from_block_size:, :] = last_attn_weights # all keys global + + else: + attention_probs = None + + return context_layer, attention_probs + + @staticmethod + def torch_gather_b2(params, indices): + # this operation is equilvalent to tf.gather when batch_dims=2 + + if params.shape[:2] != indices.shape[:2]: + raise ValueError( + f"Make sure that the first two dimensions of params and indices are identical, \ + but they are params: {params.shape[:2]} vs. indices: {params.shape[:2]}" + ) + num_indices_to_gather = indices.shape[-2] * indices.shape[-1] + num_indices_to_pick_from = params.shape[2] + + indices_shift = ( + torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device) + // num_indices_to_gather + * num_indices_to_pick_from + ) + + flattened_indices = indices.view(-1) + indices_shift + flattened_params = params.reshape(-1, params.shape[-2], params.shape[-1]) + + out_flattened = flattened_params.index_select(0, flattened_indices) + + out = out_flattened.reshape(params.shape[:2] + (num_indices_to_gather,) + params.shape[3:]) + return out + + @staticmethod + def _create_rand_mask_from_inputs( + from_blocked_mask, + to_blocked_mask, + rand_attn, + num_attention_heads, + num_rand_blocks, + batch_size, + from_seq_length, + from_block_size, + ): + """ + Create 3D attention mask from a 2D tensor mask. + + Args: + from_blocked_mask: 2D Tensor of shape [batch_size, + from_seq_length//from_block_size, from_block_size]. + to_blocked_mask: int32 Tensor of shape [batch_size, + to_seq_length//to_block_size, to_block_size]. + rand_attn: [batch_size, num_attention_heads, + from_seq_length//from_block_size-2, num_rand_blocks] + num_attention_heads: int. Number of attention heads. + num_rand_blocks: int. Number of random chunks per row. + batch_size: int. Batch size for computation. + from_seq_length: int. length of from sequence. + from_block_size: int. size of block in from sequence. + + Returns: + float Tensor of shape [batch_size, num_attention_heads, from_seq_length//from_block_size-2, + from_block_size, num_rand_blocks*to_block_size]. + """ + num_windows = from_seq_length // from_block_size - 2 + rand_mask = torch.stack([p1[i1.flatten()] for p1, i1 in zip(to_blocked_mask, rand_attn)]) + rand_mask = rand_mask.view(batch_size, num_attention_heads, num_windows, num_rand_blocks * from_block_size) + rand_mask = torch.einsum("blq,bhlk->bhlqk", from_blocked_mask[:, 1:-1], rand_mask) + return rand_mask + + @staticmethod + def _get_rand_attn_plan(from_seq_length, from_block_size, num_rand_blocks): + """ + Gives the plan of where to put random attention. + + Args: + from_seq_length: int. length of from sequence. + from_block_size: int. size of block in from sequence. + num_rand_blocks: int. Number of random chunks per row. 
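+
+        Example:
+            with ``from_seq_length=1024``, ``from_block_size=64`` and ``num_rand_blocks=3``:
+            ``2 * 3 + 5 = 11 < 1024 // 64 = 16``, hence ``plan_from_length = [704, 1024]`` and
+            ``plan_num_rand_blocks = [3, 0]``.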
+ + Returns: + plan_from_length: ending location of from block plan_num_rand_blocks: number of random ending location for + each block + """ + + plan_from_length = [] + plan_num_rand_blocks = [] + if (2 * num_rand_blocks + 5) < (from_seq_length // from_block_size): + plan_from_length.append(int((2 * num_rand_blocks + 5) * from_block_size)) + plan_num_rand_blocks.append(num_rand_blocks) + plan_from_length.append(from_seq_length) + plan_num_rand_blocks.append(0) + elif (num_rand_blocks + 5) < (from_seq_length // from_block_size): + plan_from_length.append(int((num_rand_blocks + 5) * from_block_size)) + plan_num_rand_blocks.append(num_rand_blocks // 2) + plan_from_length.append(from_seq_length) + plan_num_rand_blocks.append(num_rand_blocks - (num_rand_blocks // 2)) + else: + plan_from_length.append(from_seq_length) + plan_num_rand_blocks.append(num_rand_blocks) + + return plan_from_length, plan_num_rand_blocks + + @staticmethod + def _bigbird_block_rand_mask( + from_seq_length, to_seq_length, from_block_size, to_block_size, num_rand_blocks, last_idx=-1 + ): + """ + Create adjacency list of random attention. + + Args: + from_seq_length: int. length of from sequence. + to_seq_length: int. length of to sequence. + from_block_size: int. size of block in from sequence. + to_block_size: int. size of block in to sequence. + num_rand_blocks: int. Number of random chunks per row. + last_idx: if -1 then num_rand_blocks blocks chosen anywhere in to sequence, + if positive then num_rand_blocks blocks choosen only upto last_idx. + + Returns: + adjacency list of size from_seq_length//from_block_size-2 by num_rand_blocks + """ + # using this method when from_seq_length in [1024, 3072, 4096] + + assert ( + from_seq_length // from_block_size == to_seq_length // to_block_size + ), "Error the number of blocks needs to be same!" + + rand_attn = np.zeros((from_seq_length // from_block_size - 2, num_rand_blocks), dtype=np.int32) + middle_seq = np.arange(1, to_seq_length // to_block_size - 1, dtype=np.int32) + last = to_seq_length // to_block_size - 1 + if last_idx > (2 * to_block_size): + last = (last_idx // to_block_size) - 1 + + r = num_rand_blocks # shorthand + for i in range(1, from_seq_length // from_block_size - 1): + start = i - 2 + end = i + if i == 1: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[2:last])[:r] + elif i == 2: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[3:last])[:r] + elif i == from_seq_length // from_block_size - 3: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r] + # Missing -3: should have been sliced till last-3 + elif i == from_seq_length // from_block_size - 2: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r] + # Missing -4: should have been sliced till last-4 + else: + if start > last: + start = last + rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r] + elif (end + 1) == last: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r] + else: + rand_attn[i - 1, :] = np.random.permutation( + np.concatenate((middle_seq[:start], middle_seq[end + 1 : last])) + )[:r] + return rand_attn + + def _bigbird_block_rand_mask_with_head( + self, + from_seq_length, + to_seq_length, + from_block_size, + to_block_size, + num_heads, + plan_from_length, + plan_num_rand_blocks, + window_block_left=1, + window_block_right=1, + global_block_top=1, + global_block_bottom=1, + global_block_left=1, + global_block_right=1, + ): + """ + Create adjacency list of random attention. + + Args: + from_seq_length: int. 
length of from sequence. + to_seq_length: int. length of to sequence. + from_block_size: int. size of block in from sequence. + to_block_size: int. size of block in to sequence. + num_heads: int. total number of heads. + plan_from_length: list. plan from length where num_random_blocks are choosen from. + plan_num_rand_blocks: list. number of rand blocks within the plan. + window_block_left: int. number of blocks of window to left of a block. + window_block_right: int. number of blocks of window to right of a block. + global_block_top: int. number of blocks at the top. + global_block_bottom: int. number of blocks at the bottom. + global_block_left: int. Number of blocks globally used to the left. + global_block_right: int. Number of blocks globally used to the right. + + Returns: + adjacency list of size num_head where each element is of size from_seq_length//from_block_size-2 by + num_rand_blocks + """ + # using this method when from_seq_length not in [1024, 3072, 4096] + + assert ( + from_seq_length // from_block_size == to_seq_length // to_block_size + ), "Error the number of blocks needs to be same!" + + assert from_seq_length in plan_from_length, "Error from sequence length not in plan!" + + # Total number of blocks in the mmask + num_blocks = from_seq_length // from_block_size + # Number of blocks per plan + plan_block_length = np.array(plan_from_length) // from_block_size + # till when to follow plan + max_plan_idx = plan_from_length.index(from_seq_length) + # Random Attention adjajency list + rand_attn = [ + np.zeros((num_blocks, np.sum(plan_num_rand_blocks[: max_plan_idx + 1])), dtype=np.int32) + for i in range(num_heads) + ] + + # We will go iteratively over the plan blocks and pick random number of + # Attention blocks from the legally allowed blocks + for plan_idx in range(max_plan_idx + 1): + rnd_r_cnt = 0 + if plan_idx > 0: + # set the row for all from_blocks starting from 0 to + # plan_block_length[plan_idx-1] + # column indx start fromm plan_block_length[plan_idx-1] and ends at + # plan_block_length[plan_idx] + if plan_num_rand_blocks[plan_idx] > 0: + rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx])) + curr_r_cnt = int(np.sum(plan_num_rand_blocks[: plan_idx + 1])) + for blk_rw_idx in range(global_block_top, plan_block_length[plan_idx - 1]): + for h in range(num_heads): + rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = self._get_single_block_row_attention( + block_id=blk_rw_idx, + to_start_block_id=plan_block_length[plan_idx - 1], + to_end_block_id=plan_block_length[plan_idx], + num_rand_blocks=plan_num_rand_blocks[plan_idx], + window_block_left=window_block_left, + window_block_right=window_block_right, + global_block_left=global_block_left, + global_block_right=global_block_right, + ) + + for pl_id in range(plan_idx): + if plan_num_rand_blocks[pl_id] == 0: + continue + for blk_rw_idx in range(plan_block_length[plan_idx - 1], plan_block_length[plan_idx]): + rnd_r_cnt = 0 + to_start_block_id = 0 + if pl_id > 0: + rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:pl_id])) + to_start_block_id = plan_block_length[pl_id - 1] + curr_r_cnt = int(np.sum(plan_num_rand_blocks[: pl_id + 1])) + for h in range(num_heads): + rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = self._get_single_block_row_attention( + block_id=blk_rw_idx, + to_start_block_id=to_start_block_id, + to_end_block_id=plan_block_length[pl_id], + num_rand_blocks=plan_num_rand_blocks[pl_id], + window_block_left=window_block_left, + window_block_right=window_block_right, + global_block_left=global_block_left, + 
global_block_right=global_block_right, + ) + + if plan_num_rand_blocks[plan_idx] == 0: + continue + curr_r_cnt = int(np.sum(plan_num_rand_blocks[: plan_idx + 1])) + from_start_block_id = global_block_top + to_start_block_id = 0 + if plan_idx > 0: + rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx])) + from_start_block_id = plan_block_length[plan_idx - 1] + to_start_block_id = plan_block_length[plan_idx - 1] + + for blk_rw_idx in range(from_start_block_id, plan_block_length[plan_idx]): + for h in range(num_heads): + rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = self._get_single_block_row_attention( + block_id=blk_rw_idx, + to_start_block_id=to_start_block_id, + to_end_block_id=plan_block_length[plan_idx], + num_rand_blocks=plan_num_rand_blocks[plan_idx], + window_block_left=window_block_left, + window_block_right=window_block_right, + global_block_left=global_block_left, + global_block_right=global_block_right, + ) + + for nh in range(num_heads): + rand_attn[nh] = rand_attn[nh][global_block_top : num_blocks - global_block_bottom, :] + + return rand_attn + + @staticmethod + def _get_single_block_row_attention( + block_id, + to_start_block_id, + to_end_block_id, + num_rand_blocks, + window_block_left=1, + window_block_right=1, + global_block_left=1, + global_block_right=1, + ): + """ + For a single row block get random row attention. + + Args: + block_id: int. block id of row. + to_start_block_id: int. random attention coloum start id. + to_end_block_id: int. random attention coloum end id. + num_rand_blocks: int. number of random blocks to be selected. + window_block_left: int. number of blocks of window to left of a block. + window_block_right: int. number of blocks of window to right of a block. + global_block_left: int. Number of blocks globally used to the left. + global_block_right: int. Number of blocks globally used to the right. + + Returns: + row containing the random attention vector of size num_rand_blocks. 
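+
+        Example:
+            with ``block_id=3``, ``to_start_block_id=0``, ``to_end_block_id=16`` and the default
+            window/global sizes, blocks ``{2, 3, 4}`` (window), ``{0}`` (global left) and ``{15}``
+            (global right) are excluded before ``num_rand_blocks`` blocks are sampled from the rest.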
+ """ + # list of to_blocks from which to choose random attention + to_block_list = np.arange(to_start_block_id, to_end_block_id, dtype=np.int32) + # permute the blocks + perm_block = np.random.permutation(to_block_list) + + # illegal blocks for the current block id, using window + illegal_blocks = list(range(block_id - window_block_left, block_id + window_block_right + 1)) + + # Add blocks at the start and at the end + illegal_blocks.extend(list(range(global_block_left))) + illegal_blocks.extend(list(range(to_end_block_id - global_block_right, to_end_block_id))) + + # The second from_block cannot choose random attention on second last to_block + if block_id == 1: + illegal_blocks.append(to_end_block_id - 2) + + # The second last from_block cannot choose random attention on second to_block + if block_id == to_end_block_id - 2: + illegal_blocks.append(1) + + selected_random_blokcs = [] + + for i in range(to_end_block_id - to_start_block_id): + if perm_block[i] not in illegal_blocks: + selected_random_blokcs.append(perm_block[i]) + if len(selected_random_blokcs) == num_rand_blocks: + break + return np.array(selected_random_blokcs, dtype=np.int32) + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->BigBird +class BigBirdSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BigBirdAttention(nn.Module): + def __init__(self, config, seed=None): + super().__init__() + self.attention_type = config.attention_type + self.config = config + self.seed = seed + + if self.config.attention_type == "original_full": + self.self = BigBirdSelfAttention(config) + elif self.config.attention_type == "block_sparse": + self.self = BigBirdBlockSparseAttention(config, seed) + else: + raise ValueError( + f"attention_type can either be original_full or block_sparse, but is {self.config.attention_type}" + ) + + self.output = BigBirdSelfOutput(config) + + def set_attention_type(self, value: str): + if value not in ["original_full", "block_sparse"]: + raise ValueError( + f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}" + ) + # attention type is already correctly set + if value == self.attention_type: + return + + self.attention_type = value + if value == "original_full": + # copy all weights to new full attention class + attn_weights = BigBirdSelfAttention(self.config) + else: + # copy all weights to new sparse attention class + attn_weights = BigBirdBlockSparseAttention(self.config, self.seed) + + attn_weights.query = self.self.query + attn_weights.value = self.self.value + attn_weights.key = self.self.key + self.self = attn_weights + self.attention_type = value + + if not self.training: + self.self.eval() + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + # block_sparse config + band_mask=None, + from_mask=None, + to_mask=None, + from_blocked_mask=None, + to_blocked_mask=None, + ): + + if self.attention_type == "original_full": + self_outputs = 
self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + assert ( + encoder_hidden_states is None + ), "BigBird cannot be used as a decoder when config.attention_type != 'original_full'" + self_outputs = self.self( + hidden_states, band_mask, from_mask, to_mask, from_blocked_mask, to_blocked_mask, output_attentions + ) + + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->BigBird +class BigBirdIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->BigBird +class BigBirdOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BigBirdLayer(nn.Module): + def __init__(self, config, seed=None): + super().__init__() + self.config = config + self.attention_type = config.attention_type + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BigBirdAttention(config, seed=seed) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added" + self.crossattention = BigBirdAttention(config) + self.intermediate = BigBirdIntermediate(config) + self.output = BigBirdOutput(config) + + def set_attention_type(self, value: str): + if value not in ["original_full", "block_sparse"]: + raise ValueError( + f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}" + ) + # attention type is already correctly set + if value == self.attention_type: + return + self.attention_type = value + self.attention.set_attention_type(value) + + if self.add_cross_attention: + self.crossattention.set_attention_type(value) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + band_mask=None, + from_mask=None, + to_mask=None, + blocked_encoder_mask=None, + past_key_value=None, + output_attentions=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + 
past_key_value=self_attn_past_key_value, + output_attentions=output_attentions, + band_mask=band_mask, + from_mask=from_mask, + to_mask=to_mask, + from_blocked_mask=blocked_encoder_mask, + to_blocked_mask=blocked_encoder_mask, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with \ + cross-attention layers by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BigBirdEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.attention_type = config.attention_type + + self.layer = nn.ModuleList( + [BigBirdLayer(config, seed=layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + def set_attention_type(self, value: str): + if value not in ["original_full", "block_sparse"]: + raise ValueError( + f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}" + ) + # attention type is already correctly set + if value == self.attention_type: + return + self.attention_type = value + for layer in self.layer: + layer.set_attention_type(value) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + band_mask=None, + from_mask=None, + to_mask=None, + blocked_encoder_mask=None, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + 
(hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + band_mask, + from_mask, + to_mask, + blocked_encoder_mask, + ) + else: + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + band_mask, + from_mask, + to_mask, + blocked_encoder_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->BigBird +class BigBirdPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->BigBird +class BigBirdLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BigBirdPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
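+        # (the tying of `decoder.weight` to the input word embeddings is typically handled by
+        # `PreTrainedModel.tie_weights` when `config.tie_word_embeddings` is enabled; only the
+        # bias below is a genuinely new parameter)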
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->BigBird +class BigBirdOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BigBirdLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyNSPHead with Bert->BigBird +class BigBirdOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +# Copied from transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert->BigBird +class BigBirdPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BigBirdLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BigBirdPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BigBirdConfig + load_tf_weights = load_tf_weights_in_big_bird + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +BIG_BIRD_START_DOCSTRING = r""" + This model is a PyTorch `torch.nn.Module `_ sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config (:class:`~transformers.BigBirdConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +BIG_BIRD_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`transformers.BigBirdTokenizer`. 
See + :func:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +@dataclass +class BigBirdForPreTrainingOutput(ModelOutput): + """ + Output type of :class:`~transformers.BigBirdtForPreTraining`. + + Args: + loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: torch.FloatTensor = None + seq_relationship_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@add_start_docstrings( + "The bare BigBird Model transformer outputting raw hidden-states without any specific head on top.", + BIG_BIRD_START_DOCSTRING, +) +class BigBirdModel(BigBirdPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration + set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder` + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.attention_type = self.config.attention_type + self.config = config + + self.block_size = self.config.block_size + + self.embeddings = BigBirdEmbeddings(config) + self.encoder = BigBirdEncoder(config) + + if add_pooling_layer: + self.pooler = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + else: + self.pooler = None + self.activation = None + + if self.attention_type != "original_full" and config.add_cross_attention: + logger.warning( + "When using `BigBirdForCausalLM` as decoder, then `attention_type` must be `original_full`. 
Setting `attention_type=original_full`" + ) + self.set_attention_type("original_full") + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def set_attention_type(self, value: str): + if value not in ["original_full", "block_sparse"]: + raise ValueError( + f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}" + ) + # attention type is already correctly set + if value == self.attention_type: + return + self.attention_type = value + self.encoder.set_attention_type(value) + + @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
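+
+            A minimal usage sketch (the checkpoint name follows the ``google/bigbird-roberta-base`` identifier used
+            elsewhere in this file; with an input this short the model falls back from ``block_sparse`` to
+            ``original_full`` attention and logs a warning)::
+
+                >>> from transformers import BigBirdTokenizer, BigBirdModel
+                >>> tokenizer = BigBirdTokenizer.from_pretrained('google/bigbird-roberta-base')
+                >>> model = BigBirdModel.from_pretrained('google/bigbird-roberta-base')
+                >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+                >>> outputs = model(**inputs)
+                >>> last_hidden_state = outputs.last_hidden_state  # (batch_size, sequence_length, hidden_size)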
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # in order to use block_sparse attention, sequence_length has to be at least + # bigger than all global attentions: 2 * block_size + # + sliding tokens: 3 * block_size + # + random tokens: 2 * num_random_blocks * block_size + max_tokens_to_attend = (5 + 2 * self.config.num_random_blocks) * self.config.block_size + if self.attention_type == "block_sparse" and seq_length <= max_tokens_to_attend: + # change attention_type from block_sparse to original_full + sequence_length = input_ids.size(1) if input_ids is not None else inputs_embeds.size(1) + logger.warning( + "Attention type 'block_sparse' is not possible if sequence_length: " + f"{sequence_length} <= num global tokens: 2 * config.block_size " + "+ min. num sliding tokens: 3 * config.block_size " + "+ config.num_random_blocks * config.block_size " + "+ additional buffer: config.num_random_blocks * config.block_size " + f"= {max_tokens_to_attend} with config.block_size " + f"= {self.config.block_size}, config.num_random_blocks " + f"= {self.config.num_random_blocks}." + "Changing attention type to 'original_full'..." + ) + self.set_attention_type("original_full") + + if self.attention_type == "block_sparse": + ( + padding_len, + input_ids, + attention_mask, + token_type_ids, + position_ids, + inputs_embeds, + ) = self._pad_to_block_size( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + pad_token_id=self.config.pad_token_id, + ) + else: + padding_len = 0 + + if self.attention_type == "block_sparse": + blocked_encoder_mask, band_mask, from_mask, to_mask = self.create_masks_for_block_sparse_attn( + attention_mask, self.block_size + ) + extended_attention_mask = None + + elif self.attention_type == "original_full": + blocked_encoder_mask = None + band_mask = None + from_mask = None + to_mask = None + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device + ) + else: + raise ValueError( + f"attention_type can either be original_full or block_sparse, but is {self.attention_type}" + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + band_mask=band_mask, + from_mask=from_mask, + to_mask=to_mask, + blocked_encoder_mask=blocked_encoder_mask, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + + pooler_output = self.activation(self.pooler(sequence_output[:, 0, :])) if (self.pooler is not None) else None + + # undo padding + if padding_len > 0: + # unpad `sequence_output` because the calling function is expecting a length == input_ids.size(1) + sequence_output = sequence_output[:, :-padding_len] + + if not return_dict: + return (sequence_output, pooler_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooler_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + @staticmethod + def create_masks_for_block_sparse_attn(attention_mask: torch.Tensor, block_size: int): + + batch_size, seq_length = attention_mask.size() + assert ( + seq_length % block_size == 0 + ), f"Sequence length must be multiple of block size, but sequence length is {seq_length}, while block size is {block_size}." + + def create_band_mask_from_inputs(from_blocked_mask, to_blocked_mask): + """ + Create 3D attention mask from a 2D tensor mask. + + Args: + from_blocked_mask: 2D Tensor of shape [batch_size, + from_seq_length//from_block_size, from_block_size]. + to_blocked_mask: int32 Tensor of shape [batch_size, + to_seq_length//to_block_size, to_block_size]. + + Returns: + float Tensor of shape [batch_size, 1, from_seq_length//from_block_size-4, from_block_size, + 3*to_block_size]. 
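+
+            For example (illustrative numbers only): with ``batch_size=1``, ``seq_length=128`` and
+            ``block_size=16``, the blocked masks passed in have shape ``[1, 8, 16]`` and the returned band mask
+            has shape ``[1, 1, 4, 16, 48]``.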
+ """ + exp_blocked_to_pad = torch.cat( + [to_blocked_mask[:, 1:-3], to_blocked_mask[:, 2:-2], to_blocked_mask[:, 3:-1]], dim=2 + ) + band_mask = torch.einsum("blq,blk->blqk", from_blocked_mask[:, 2:-2], exp_blocked_to_pad) + band_mask.unsqueeze_(1) + return band_mask + + blocked_encoder_mask = attention_mask.view(batch_size, seq_length // block_size, block_size) + band_mask = create_band_mask_from_inputs(blocked_encoder_mask, blocked_encoder_mask) + + from_mask = attention_mask.view(batch_size, 1, seq_length, 1) + to_mask = attention_mask.view(batch_size, 1, 1, seq_length) + + return blocked_encoder_mask, band_mask, from_mask, to_mask + + def _pad_to_block_size( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: torch.Tensor, + position_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + pad_token_id: int, + ): + """A helper function to pad tokens and mask to work with implementation of BigBird block-sparse attention.""" + # padding + block_size = self.config.block_size + + input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape + batch_size, seq_len = input_shape[:2] + + padding_len = (block_size - seq_len % block_size) % block_size + if padding_len > 0: + logger.info( + "Input ids are automatically padded from {} to {} to be a multiple of `config.block_size`: {}".format( + seq_len, seq_len + padding_len, block_size + ) + ) + if input_ids is not None: + input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id) + if position_ids is not None: + # pad with position_id = pad_token_id as in modeling_bigbird.BigBirdEmbeddings + position_ids = F.pad(position_ids, (0, padding_len), value=pad_token_id) + if inputs_embeds is not None: + input_ids_padding = inputs_embeds.new_full( + (batch_size, padding_len), + self.config.pad_token_id, + dtype=torch.long, + ) + inputs_embeds_padding = self.embeddings(input_ids_padding) + inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2) + + attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens + token_type_ids = F.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0 + + return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds + + +class BigBirdForPreTraining(BigBirdPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BigBirdModel(config, add_pooling_layer=True) + self.cls = BigBirdPreTrainingHeads(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=BigBirdForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + next_sentence_label=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`): + Labels for computing the masked language modeling loss. 
Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`): + Labels for computing the next sequence prediction (classification) loss. If specified, nsp loss will be + added to masked_lm loss. Input should be a sequence pair (see :obj:`input_ids` docstring) Indices should be + in ``[0, 1]``: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): + Used to hide legacy arguments that have been deprecated. + + Returns: + + Example:: + + >>> from transformers import BigBirdTokenizer, BigBirdForPreTraining + >>> import torch + + >>> tokenizer = BigBirdTokenizer.from_pretrained('bigbird-roberta-base') + >>> model = BigBirdForPreTraining.from_pretrained('bigbird-roberta-base') + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.seq_relationship_logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + total_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if next_sentence_label is not None and total_loss is not None: + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = total_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BigBirdForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings("""BigBird Model with a `language modeling` head on top. """, BIG_BIRD_START_DOCSTRING) +class BigBirdForMaskedLM(BigBirdPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `BigBirdForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." 
+ ) + + self.bert = BigBirdModel(config) + self.cls = BigBirdOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + assert self.config.pad_token_id is not None, "The PAD token should be defined for generation" + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +@add_start_docstrings( + """BigBird Model with a `language modeling` head on top for CLM fine-tuning. 
""", BIG_BIRD_START_DOCSTRING +) +class BigBirdForCausalLM(BigBirdPreTrainedModel): + + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `BigBirdForCausalLM` as a standalone, add `is_decoder=True.`") + + self.bert = BigBirdModel(config) + self.cls = BigBirdOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
+ + Returns: + + Example:: + + >>> from transformers import BigBirdTokenizer, BigBirdForCausalLM, BigBirdConfig + >>> import torch + + >>> tokenizer = BigBirdTokenizer.from_pretrained('google/bigbird-roberta-base') + >>> config = BigBirdConfig.from_pretrained("google/bigbird-base") + >>> config.is_decoder = True + >>> model = BigBirdForCausalLM.from_pretrained('google/bigbird-roberta-base', config=config) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + +class BigBirdClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + self.config = config + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = ACT2FN[self.config.hidden_act](x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +@add_start_docstrings( + """ + BigBird Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. 
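+
+    A minimal fine-tuning sketch (the checkpoint name and ``num_labels=2`` are illustrative assumptions)::
+
+        >>> from transformers import BigBirdTokenizer, BigBirdForSequenceClassification
+        >>> import torch
+        >>> tokenizer = BigBirdTokenizer.from_pretrained('google/bigbird-roberta-base')
+        >>> model = BigBirdForSequenceClassification.from_pretrained('google/bigbird-roberta-base', num_labels=2)
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> outputs = model(**inputs, labels=torch.tensor([1]))
+        >>> loss, logits = outputs.loss, outputs.logits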
+ """, + BIG_BIRD_START_DOCSTRING, +) +class BigBirdForSequenceClassification(BigBirdPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.bert = BigBirdModel(config) + self.classifier = BigBirdClassificationHead(config) + + self.init_weights() + + @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + BigBird Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + BIG_BIRD_START_DOCSTRING, +) +class BigBirdForMultipleChoice(BigBirdPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BigBirdModel(config) + self.sequence_summary = SequenceSummary(config) + self.classifier = nn.Linear(config.hidden_size, 1) + + self.init_weights() + + @add_start_docstrings_to_model_forward( + BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") + ) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the multiple choice classification loss. 
Indices should be in ``[0, ..., + num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See + :obj:`input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + pooled_output = self.sequence_summary(sequence_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + BigBird Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + BIG_BIRD_START_DOCSTRING, +) +class BigBirdForTokenClassification(BigBirdPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BigBirdModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - + 1]``. 
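+
+            A minimal sketch (the checkpoint name and the all-zero label tensor are illustrative assumptions)::
+
+                >>> from transformers import BigBirdTokenizer, BigBirdForTokenClassification
+                >>> import torch
+                >>> tokenizer = BigBirdTokenizer.from_pretrained('google/bigbird-roberta-base')
+                >>> model = BigBirdForTokenClassification.from_pretrained('google/bigbird-roberta-base')
+                >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+                >>> labels = torch.zeros_like(inputs["input_ids"])  # one label id per token, in [0, num_labels - 1]
+                >>> outputs = model(**inputs, labels=labels)
+                >>> loss, logits = outputs.loss, outputs.logits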
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class BigBirdForQuestionAnsweringHead(nn.Module): + """Head for question answering tasks.""" + + def __init__(self, config): + super().__init__() + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.intermediate = BigBirdIntermediate(config) + self.output = BigBirdOutput(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, encoder_output): + hidden_states = self.dropout(encoder_output) + hidden_states = self.intermediate(hidden_states) + hidden_states = self.output(hidden_states, encoder_output) + hidden_states = self.qa_outputs(hidden_states) + return hidden_states + + +@add_start_docstrings( + """ + BigBird Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + BIG_BIRD_START_DOCSTRING, +) +class BigBirdForQuestionAnswering(BigBirdPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + config.num_labels = 2 + self.num_labels = config.num_labels + self.sep_token_id = config.sep_token_id + + self.bert = BigBirdModel(config, add_pooling_layer=False) + self.qa_classifier = BigBirdForQuestionAnsweringHead(config) + + self.init_weights() + + @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="google/bigbird-base-trivia-itc", + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + question_lengths=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. 
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + seqlen = input_ids.size(1) if input_ids is not None else inputs_embeds.size(1) + + if question_lengths is None and input_ids is not None: + # assuming input_ids format: context + question_lengths = torch.argmax(input_ids.eq(self.sep_token_id).int(), dim=-1) + 1 + question_lengths.unsqueeze_(1) + + logits_mask = None + if question_lengths is not None: + # setting lengths logits to `-infi` + logits_mask = self.prepare_question_mask(question_lengths, seqlen) + if token_type_ids is None: + token_type_ids = (~logits_mask).long() + logits_mask = logits_mask + logits_mask.unsqueeze_(2) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + logits = self.qa_classifier(sequence_output) + + if logits_mask is not None: + # removing question tokens from the competition + logits = logits - logits_mask * 1e6 + + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @staticmethod + def prepare_question_mask(q_lengths: torch.Tensor, maxlen: int): + # q_lengths -> (bz, 1) + mask = torch.arange(0, maxlen).to(q_lengths.device) + mask.unsqueeze_(0) # -> (1, maxlen) + mask = mask < q_lengths + return mask diff --git a/src/transformers/models/big_bird/tokenization_big_bird.py b/src/transformers/models/big_bird/tokenization_big_bird.py new file mode 100644 index 00000000000000..650f02dea169ae --- /dev/null +++ b/src/transformers/models/big_bird/tokenization_big_bird.py @@ -0,0 +1,231 @@ +# coding=utf-8 +# Copyright 2021 Google Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for BigBird.""" + + +import os +from shutil import copyfile +from typing import List, Optional, Tuple + +import sentencepiece as spm + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "google/bigbird-roberta-base": "https://huggingface.co/google/bigbird-roberta-base/resolve/main/spiece.model", + "google/bigbird-roberta-large": "https://huggingface.co/google/bigbird-roberta-large/resolve/main/spiece.model", + "google/bigbird-base-trivia-itc": "https://huggingface.co/google/bigbird-base-trivia-itc/resolve/main/spiece.model", + } +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "google/bigbird-roberta-base": 4096, + "google/bigbird-roberta-large": 4096, + "google/bigbird-base-trivia-itc": 4096, +} + + +class BigBirdTokenizer(PreTrainedTokenizer): + """ + Construct a BigBird tokenizer. Based on `SentencePiece `__. + + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. + Users should refer to this superclass for more information regarding those methods. + + Args: + vocab_file (:obj:`str`): + `SentencePiece `__ file (generally has a `.spm` extension) that + contains the vocabulary necessary to instantiate a tokenizer. + eos_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The end of sequence token. + bos_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The begin of sequence token. + unk_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The token used for padding, for example when batching sequences of different lengths. + sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. 
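+
+    A minimal usage sketch (the checkpoint name is one of the entries in the pretrained map above; the exact ids
+    depend on the SentencePiece vocabulary)::
+
+        >>> from transformers import BigBirdTokenizer
+        >>> tokenizer = BigBirdTokenizer.from_pretrained('google/bigbird-roberta-base')
+        >>> encoding = tokenizer("Hello, my dog is cute")
+        >>> encoding["input_ids"]  # ids for [CLS] ... [SEP], see ``build_inputs_with_special_tokens`` below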
+ + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + prefix_tokens: List[int] = [] + + def __init__( + self, + vocab_file, + unk_token="", + bos_token="", + eos_token="", + pad_token="", + sep_token="[SEP]", + mask_token="[MASK]", + cls_token="[CLS]", + **kwargs + ): + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token + sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token + + # Mask token behave like a normal word, i.e. include the space before it + mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + sep_token=sep_token, + mask_token=mask_token, + cls_token=cls_token, + **kwargs, + ) + + self.vocab_file = vocab_file + + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(vocab_file) + + @property + def vocab_size(self): + return self.sp_model.get_piece_size() + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(self.vocab_file) + + def _tokenize(self, text, sample=False): + """Take as input a string and return a list of strings (tokens) for words/sub-words""" + if not sample: + pieces = self.sp_model.EncodeAsPieces(text) + else: + pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1) + return pieces + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. 
""" + out_string = self.sp_model.decode_pieces(tokens) + return out_string + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A Big Bird sequence has the following format: + + - single sequence: ``[CLS] X [SEP]`` + - pair of sequences: ``[CLS] A [SEP] B [SEP]`` + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` method. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model." 
+ ) + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 00a84b68107ddb..cf9109d3607fb1 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -613,6 +613,91 @@ def load_tf_weights_in_bert_generation(*args, **kwargs): requires_pytorch(load_tf_weights_in_bert_generation) +BIG_BIRD_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class BigBirdForCausalLM: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class BigBirdForMaskedLM: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + +class BigBirdForMultipleChoice: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + +class BigBirdForPreTraining: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class BigBirdForQuestionAnswering: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + +class BigBirdForSequenceClassification: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + +class BigBirdForTokenClassification: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + +class BigBirdLayer: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class BigBirdModel: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + +class BigBirdPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + +def load_tf_weights_in_big_bird(*args, **kwargs): + requires_pytorch(load_tf_weights_in_big_bird) + + BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/modeling_auto_mapping.py b/src/transformers/utils/modeling_auto_mapping.py index 45424f4f029c38..189b2e1959f4fd 100644 --- a/src/transformers/utils/modeling_auto_mapping.py +++ b/src/transformers/utils/modeling_auto_mapping.py @@ -6,6 +6,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ + ("BigBirdConfig", "BigBirdForQuestionAnswering"), ("ConvBertConfig", "ConvBertForQuestionAnswering"), ("LEDConfig", "LEDForQuestionAnswering"), ("DistilBertConfig", "DistilBertForQuestionAnswering"), diff --git a/tests/test_modeling_big_bird.py b/tests/test_modeling_big_bird.py new file mode 100644 index 00000000000000..4eb72128e3d8f0 --- /dev/null +++ b/tests/test_modeling_big_bird.py @@ -0,0 +1,906 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch BigBird model. """ + + +import unittest + +from tests.test_modeling_common import floats_tensor +from transformers import is_torch_available +from transformers.models.big_bird.tokenization_big_bird import BigBirdTokenizer +from transformers.testing_utils import require_torch, slow, torch_device + +from .test_configuration_common import ConfigTester +from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask + + +if is_torch_available(): + import torch + + from transformers import ( + MODEL_FOR_PRETRAINING_MAPPING, + BigBirdConfig, + BigBirdForCausalLM, + BigBirdForMaskedLM, + BigBirdForMultipleChoice, + BigBirdForPreTraining, + BigBirdForQuestionAnswering, + BigBirdForSequenceClassification, + BigBirdForTokenClassification, + BigBirdModel, + ) + from transformers.models.big_bird.modeling_big_bird import BIG_BIRD_PRETRAINED_MODEL_ARCHIVE_LIST + + +class BigBirdModelTester: + def __init__( + self, + parent, + batch_size=7, + seq_length=128, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu_fast", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=256, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + attention_type="block_sparse", + use_bias=True, + rescale_embeddings=False, + block_size=16, + num_rand_blocks=3, + position_embedding_type="absolute", + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + + self.attention_type = attention_type + self.use_bias = use_bias + self.rescale_embeddings = rescale_embeddings + self.block_size = block_size + self.num_rand_blocks = num_rand_blocks + self.position_embedding_type = position_embedding_type + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = 
None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = BigBirdConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_encoder_decoder=False, + initializer_range=self.initializer_range, + attention_type=self.attention_type, + use_bias=self.use_bias, + rescale_embeddings=self.rescale_embeddings, + block_size=self.block_size, + num_random_blocks=self.num_rand_blocks, + position_embedding_type=self.position_embedding_type, + ) + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + config.is_decoder = True + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + + def create_and_check_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = BigBirdModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + result = model(input_ids, token_type_ids=token_type_ids) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_pretraining( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = BigBirdForPreTraining(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + labels=token_labels, + next_sentence_label=sequence_labels, + ) + self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, config.num_labels)) + + def create_and_check_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + model = BigBirdModel(config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + encoder_hidden_states=encoder_hidden_states, + ) + result = 
model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+    def create_and_check_for_causal_lm(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        encoder_hidden_states,
+        encoder_attention_mask,
+    ):
+        model = BigBirdForCausalLM(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+
+    def create_and_check_for_masked_lm(
+        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+    ):
+        model = BigBirdForMaskedLM(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+
+    def create_and_check_decoder_model_past_large_inputs(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        encoder_hidden_states,
+        encoder_attention_mask,
+    ):
+        config.is_decoder = True
+        config.add_cross_attention = True
+        model = BigBirdForCausalLM(config=config)
+        model.to(torch_device)
+        model.eval()
+
+        # first forward pass
+        outputs = model(
+            input_ids,
+            attention_mask=input_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            use_cache=True,
+        )
+        past_key_values = outputs.past_key_values
+
+        # create hypothetical multiple next tokens and extend them to next_input_ids
+        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
+
+        # append to next input_ids and next_attention_mask
+        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
+
+        output_from_no_past = model(
+            next_input_ids,
+            attention_mask=next_attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_hidden_states=True,
+        )["hidden_states"][0]
+        output_from_past = model(
+            next_tokens,
+            attention_mask=next_attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            output_hidden_states=True,
+        )["hidden_states"][0]
+
+        # select random slice
+        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
+
+        # test that outputs are equal for slice
+        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+    def create_and_check_for_question_answering(
+        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+    ):
+        model = BigBirdForQuestionAnswering(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(
+            input_ids,
+            attention_mask=input_mask,
+            token_type_ids=token_type_ids,
+            start_positions=sequence_labels,
+            end_positions=sequence_labels,
+        )
+        
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length)) + self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length)) + + def create_and_check_for_sequence_classification( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = BigBirdForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def create_and_check_for_token_classification( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = BigBirdForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels)) + + def create_and_check_for_multiple_choice( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_choices = self.num_choices + model = BigBirdForMultipleChoice(config=config) + model.to(torch_device) + model.eval() + multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + result = model( + multiple_choice_inputs_ids, + attention_mask=multiple_choice_input_mask, + token_type_ids=multiple_choice_token_type_ids, + labels=choice_labels, + ) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask} + return config, inputs_dict + + def create_and_check_for_auto_padding( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = BigBirdModel(config) + model.to(torch_device) + model.eval() + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_change_to_full_attn( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = BigBirdModel(config) + model.to(torch_device) + model.eval() + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + # the config should not be changed + self.parent.assertTrue(model.config.attention_type == "block_sparse") + + +@require_torch +class BigBirdModelTest(ModelTesterMixin, unittest.TestCase): + + # head masking & pruning is currently not supported for big bird + test_head_masking = False + test_pruning = False + + # torchscript should be possible, but takes prohibitively long to test. 
+ # Also torchscript is not an important feature to have in the beginning. + test_torchscript = False + + all_model_classes = ( + ( + BigBirdModel, + BigBirdForPreTraining, + BigBirdForMaskedLM, + BigBirdForCausalLM, + BigBirdForMultipleChoice, + BigBirdForQuestionAnswering, + BigBirdForSequenceClassification, + BigBirdForTokenClassification, + ) + if is_torch_available() + else () + ) + all_generative_model_classes = (BigBirdForCausalLM,) if is_torch_available() else () + + # special case for ForPreTraining model + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class in MODEL_FOR_PRETRAINING_MAPPING.values(): + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device + ) + inputs_dict["next_sentence_label"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + return inputs_dict + + def setUp(self): + self.model_tester = BigBirdModelTester(self) + self.config_tester = ConfigTester(self, config_class=BigBirdConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_pretraining(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_pretraining(*config_and_inputs) + + def test_for_masked_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) + + def test_for_multiple_choice(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_question_answering(*config_and_inputs) + + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) + + def test_for_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + + def test_model_as_decoder(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) + + def test_model_as_decoder_with_default_input_mask(self): + # This regression test was failing with PyTorch < 1.3 + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) = self.model_tester.prepare_config_and_inputs_for_decoder() + + input_mask = None + + self.model_tester.create_and_check_model_as_decoder( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + 
encoder_attention_mask, + ) + + def test_retain_grad_hidden_states_attentions(self): + # bigbird cannot keep gradients in attentions when `attention_type=block_sparse` + + if self.model_tester.attention_type == "original_full": + super().test_retain_grad_hidden_states_attentions() + + @slow + def test_model_from_pretrained(self): + for model_name in BIG_BIRD_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = BigBirdForPreTraining.from_pretrained(model_name) + self.assertIsNotNone(model) + + def test_model_various_attn_type(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for type in ["original_full", "block_sparse"]: + config_and_inputs[0].attention_type = type + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skipIf(torch_device == "cpu", "Fast integration only compatible on GPU") + def test_fast_integration(self): + torch.manual_seed(0) + + input_ids = torch.randint( + self.model_tester.vocab_size, + (self.model_tester.batch_size, self.model_tester.seq_length), + device=torch_device, + ) + attention_mask = torch.ones((self.model_tester.batch_size, self.model_tester.seq_length), device=torch_device) + attention_mask[:, :-10] = 0 + token_type_ids = torch.randint( + self.model_tester.type_vocab_size, + (self.model_tester.batch_size, self.model_tester.seq_length), + device=torch_device, + ) + + config, _, _, _, _, _, _ = self.model_tester.prepare_config_and_inputs() + model = BigBirdModel(config).to(torch_device).eval() + + with torch.no_grad(): + hidden_states = model( + input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask + ).last_hidden_state + self.assertTrue( + torch.allclose( + hidden_states[0, 0, :5], + torch.tensor([-0.6326, 0.6124, -0.0844, 0.6698, -1.7155], device=torch_device), + atol=1e-3, + ) + ) + + def test_auto_padding(self): + self.model_tester.seq_length = 241 + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_auto_padding(*config_and_inputs) + + def test_for_change_to_full_attn(self): + self.model_tester.seq_length = 9 + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_change_to_full_attn(*config_and_inputs) + + +@require_torch +@slow +class BigBirdModelIntegrationTest(unittest.TestCase): + # we can have this true once block_sparse attn_probs works accurately + test_attention_probs = False + + def _get_dummy_input_ids(self): + # fmt: off + ids = torch.tensor( + [[6, 117, 33, 36, 70, 22, 63, 31, 71, 72, 88, 58, 109, 49, 48, 116, 92, 6, 19, 95, 118, 100, 80, 111, 93, 2, 31, 84, 26, 5, 6, 82, 46, 96, 109, 4, 39, 19, 109, 13, 92, 31, 36, 90, 111, 18, 75, 6, 56, 74, 16, 42, 56, 92, 69, 108, 127, 81, 82, 41, 106, 19, 44, 24, 82, 121, 120, 65, 36, 26, 72, 13, 36, 98, 43, 64, 8, 53, 100, 92, 51, 122, 66, 17, 61, 50, 104, 127, 26, 35, 94, 23, 110, 71, 80, 67, 109, 111, 44, 19, 51, 41, 86, 71, 76, 44, 18, 68, 44, 77, 107, 81, 98, 126, 100, 2, 49, 98, 84, 39, 23, 98, 52, 46, 10, 82, 121, 73]], # noqa: E231 + dtype=torch.long, + device=torch_device, + ) + # fmt: on + return ids + + def test_inference_block_sparse_pretraining(self): + model = BigBirdForPreTraining.from_pretrained("google/bigbird-roberta-base", attention_type="block_sparse") + model.to(torch_device) + + input_ids = torch.tensor([[20920, 232, 328, 1437] * 1024], dtype=torch.long, device=torch_device) + outputs = model(input_ids) + prediction_logits = outputs.prediction_logits + seq_relationship_logits = outputs.seq_relationship_logits + + 
self.assertEqual(prediction_logits.shape, torch.Size((1, 4096, 50358))) + self.assertEqual(seq_relationship_logits.shape, torch.Size((1, 2))) + + expected_prediction_logits_slice = torch.tensor( + [ + [-0.2420, -0.6048, -0.0614, 7.8422], + [-0.0596, -0.0104, -1.8408, 9.3352], + [1.0588, 0.7999, 5.0770, 8.7555], + [-0.1385, -1.7199, -1.7613, 6.1094], + ], + device=torch_device, + ) + self.assertTrue( + torch.allclose(prediction_logits[0, 128:132, 128:132], expected_prediction_logits_slice, atol=1e-4) + ) + + expected_seq_relationship_logits = torch.tensor([[58.8196, 56.3629]], device=torch_device) + self.assertTrue(torch.allclose(seq_relationship_logits, expected_seq_relationship_logits, atol=1e-4)) + + def test_inference_full_pretraining(self): + model = BigBirdForPreTraining.from_pretrained("google/bigbird-roberta-base", attention_type="original_full") + model.to(torch_device) + + input_ids = torch.tensor([[20920, 232, 328, 1437] * 512], dtype=torch.long, device=torch_device) + outputs = model(input_ids) + prediction_logits = outputs.prediction_logits + seq_relationship_logits = outputs.seq_relationship_logits + + self.assertEqual(prediction_logits.shape, torch.Size((1, 512 * 4, 50358))) + self.assertEqual(seq_relationship_logits.shape, torch.Size((1, 2))) + + expected_prediction_logits_slice = torch.tensor( + [ + [0.1499, -1.1217, 0.1990, 8.4499], + [-2.7757, -3.0687, -4.8577, 7.5156], + [1.5446, 0.1982, 4.3016, 10.4281], + [-1.3705, -4.0130, -3.9629, 5.1526], + ], + device=torch_device, + ) + self.assertTrue( + torch.allclose(prediction_logits[0, 128:132, 128:132], expected_prediction_logits_slice, atol=1e-4) + ) + + expected_seq_relationship_logits = torch.tensor([[41.4503, 41.2406]], device=torch_device) + self.assertTrue(torch.allclose(seq_relationship_logits, expected_seq_relationship_logits, atol=1e-4)) + + def test_block_sparse_attention_probs(self): + """ + Asserting if outputted attention matrix is similar to hard coded attention matrix + """ + + if not self.test_attention_probs: + return + + model = BigBirdModel.from_pretrained( + "google/bigbird-roberta-base", attention_type="block_sparse", num_random_blocks=3, block_size=16 + ) + model.to(torch_device) + model.eval() + config = model.config + + input_ids = self._get_dummy_input_ids() + + hidden_states = model.embeddings(input_ids) + + batch_size, seqlen, _ = hidden_states.size() + attn_mask = torch.ones(batch_size, seqlen, device=torch_device, dtype=torch.float) + to_seq_length = from_seq_length = seqlen + from_block_size = to_block_size = config.block_size + + blocked_mask, band_mask, from_mask, to_mask = model.create_masks_for_block_sparse_attn( + attn_mask, config.block_size + ) + from_blocked_mask = to_blocked_mask = blocked_mask + + for i in range(config.num_hidden_layers): + pointer = model.encoder.layer[i].attention.self + + query_layer = pointer.transpose_for_scores(pointer.query(hidden_states)) + key_layer = pointer.transpose_for_scores(pointer.key(hidden_states)) + value_layer = pointer.transpose_for_scores(pointer.value(hidden_states)) + + context_layer, attention_probs = pointer.bigbird_block_sparse_attention( + query_layer, + key_layer, + value_layer, + band_mask, + from_mask, + to_mask, + from_blocked_mask, + to_blocked_mask, + pointer.num_attention_heads, + pointer.num_random_blocks, + pointer.attention_head_size, + from_block_size, + to_block_size, + batch_size, + from_seq_length, + to_seq_length, + seed=pointer.seed, + plan_from_length=None, + plan_num_rand_blocks=None, + output_attentions=True, + ) + + 
context_layer = context_layer.contiguous().view(batch_size, from_seq_length, -1) + cl = torch.einsum("bhqk,bhkd->bhqd", attention_probs, value_layer) + cl = cl.view(context_layer.size()) + + self.assertTrue(torch.allclose(context_layer, cl, atol=0.001)) + + def test_block_sparse_context_layer(self): + model = BigBirdModel.from_pretrained( + "google/bigbird-roberta-base", attention_type="block_sparse", num_random_blocks=3, block_size=16 + ) + model.to(torch_device) + model.eval() + config = model.config + + input_ids = self._get_dummy_input_ids() + dummy_hidden_states = model.embeddings(input_ids) + + attn_mask = torch.ones_like(input_ids, device=torch_device) + blocked_mask, band_mask, from_mask, to_mask = model.create_masks_for_block_sparse_attn( + attn_mask, config.block_size + ) + targeted_cl = torch.tensor( + [ + [0.1874, 1.5260, 0.2335, -0.0473, -0.0961, 1.8384, -0.0141, 0.1250, 0.0085, -0.0048], + [-0.0554, 0.0728, 0.1683, -0.1332, 0.1741, 0.1337, -0.2380, -0.1849, -0.0390, -0.0259], + [-0.0419, 0.0767, 0.1591, -0.1399, 0.1789, 0.1257, -0.2406, -0.1772, -0.0261, -0.0079], + [0.1860, 1.5172, 0.2326, -0.0473, -0.0953, 1.8291, -0.0147, 0.1245, 0.0082, -0.0046], + [0.1879, 1.5296, 0.2335, -0.0471, -0.0975, 1.8433, -0.0136, 0.1260, 0.0086, -0.0054], + [0.1854, 1.5147, 0.2334, -0.0480, -0.0956, 1.8250, -0.0149, 0.1222, 0.0082, -0.0060], + [0.1859, 1.5184, 0.2334, -0.0474, -0.0955, 1.8297, -0.0143, 0.1234, 0.0079, -0.0054], + [0.1885, 1.5336, 0.2335, -0.0467, -0.0979, 1.8481, -0.0130, 0.1269, 0.0085, -0.0049], + [0.1881, 1.5305, 0.2335, -0.0471, -0.0976, 1.8445, -0.0135, 0.1262, 0.0086, -0.0053], + [0.1852, 1.5148, 0.2333, -0.0480, -0.0949, 1.8254, -0.0151, 0.1225, 0.0079, -0.0055], + [0.1877, 1.5292, 0.2335, -0.0470, -0.0972, 1.8431, -0.0135, 0.1259, 0.0084, -0.0052], + [0.1874, 1.5261, 0.2334, -0.0472, -0.0968, 1.8393, -0.0140, 0.1251, 0.0084, -0.0052], + [0.1853, 1.5151, 0.2331, -0.0478, -0.0948, 1.8256, -0.0154, 0.1228, 0.0086, -0.0052], + [0.1867, 1.5233, 0.2334, -0.0475, -0.0965, 1.8361, -0.0139, 0.1247, 0.0084, -0.0054], + ], + device=torch_device, + ) + + context_layer = model.encoder.layer[0].attention.self( + dummy_hidden_states, + band_mask=band_mask, + from_mask=from_mask, + to_mask=to_mask, + from_blocked_mask=blocked_mask, + to_blocked_mask=blocked_mask, + ) + context_layer = context_layer[0] + + self.assertEqual(context_layer.shape, torch.Size((1, 128, 768))) + self.assertTrue(torch.allclose(context_layer[0, 64:78, 300:310], targeted_cl, atol=0.0001)) + + def test_tokenizer_inference(self): + tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base") + model = BigBirdModel.from_pretrained( + "google/bigbird-roberta-base", attention_type="block_sparse", num_random_blocks=3, block_size=16 + ) + model.to(torch_device) + + text = [ + 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to , such as saoneuhaoesuth ... This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to , such as saoneuhaoesuth ,, I was born in 92000, and this is falsé.' 
+ ] + inputs = tokenizer(text) + + for k in inputs: + inputs[k] = torch.tensor(inputs[k], device=torch_device, dtype=torch.long) + + prediction = model(**inputs) + prediction = prediction[0] + + self.assertEqual(prediction.shape, torch.Size((1, 128, 768))) + + expected_prediction = torch.tensor( + [ + [-0.0745, 0.0689, -0.1126, -0.0610], + [-0.0343, 0.0111, -0.0269, -0.0858], + [0.1150, 0.0896, 0.0492, 0.0149], + [-0.0657, 0.2035, 0.0444, -0.0535], + [0.1143, 0.0465, 0.1583, -0.1855], + [-0.0216, 0.0807, 0.0536, 0.1371], + [-0.1879, 0.0097, -0.1916, 0.1701], + [0.7616, 0.1240, 0.0669, 0.2588], + [0.1096, -0.1810, -0.1987, 0.0445], + [0.1810, -0.3608, -0.0081, 0.1764], + [-0.0472, 0.0460, 0.0976, -0.0021], + [-0.0274, -0.3274, -0.0788, 0.0465], + ], + device=torch_device, + ) + self.assertTrue(torch.allclose(prediction[0, 52:64, 320:324], expected_prediction, atol=1e-4)) + + def test_inference_question_answering(self): + tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-base-trivia-itc") + model = BigBirdForQuestionAnswering.from_pretrained( + "google/bigbird-base-trivia-itc", attention_type="block_sparse", block_size=16, num_random_blocks=3 + ) + model.to(torch_device) + + context = "🤗 Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet…) for Natural Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between TensorFlow 2.0 and PyTorch. Extractive Question Answering is the task of extracting an answer from a text given a question. An example of a question answering dataset is the SQuAD dataset" + + question = [ + "How many pretrained models are available in 🤗 Transformers?", + "🤗 Transformers provides interoperability between which frameworks?", + ] + inputs = tokenizer( + question, + [context, context], + padding=True, + return_tensors="pt", + add_special_tokens=True, + max_length=128, + truncation=True, + ) + + inputs = {k: v.to(torch_device) for k, v in inputs.items()} + + start_logits, end_logits = model(**inputs).to_tuple() + + # fmt: off + target_start_logits = torch.tensor( + [[-9.5889, -10.2121, -14.2158, -11.1457, -10.7376, -7.3907, -10.2084, -9.5659, -15.0336, -8.6686, -9.1737, -11.1457, -13.4722, -6.3336, -9.6311, -8.4821, -15.141, -9.1226, -10.3328, -11.1457, -6.6793, -3.9627, 2.7126, -5.5607, -8.4625, -12.499, -11.4757, -9.6334, -4.0565, -10.0474, -7.4126, -13.5669], [-15.3796, -12.6863, -10.3951, -7.6706, -10.1808, -11.4401, -15.5868, -12.7959, -11.0186, -12.6863, -14.2198, -8.1182, -11.1353, -11.6512, -15.702, -12.8964, -12.5173, -12.6863, -14.4133, -13.1532, -12.2846, -14.1572, -11.2747, -11.1159, -11.5219, -13.1115, -11.8779, -13.989, -11.5234, -15.0459, -10.0178, -12.9253]], # noqa: E231 + device=torch_device, + ) + target_end_logits = torch.tensor( + [[-12.4895, -10.9826, -13.8226, -11.9922, -13.2647, -12.4584, -10.6143, -9.4091, -16.844, -14.0393, -9.5914, -11.9922, -15.5142, -11.4073, -10.1064, -8.3961, -16.4374, -13.9323, -10.791, -11.9922, -8.736, -9.5672, 0.2844, -4.0976, -13.849, -11.8035, -12.7784, -14.1314, -7.4138, -10.5488, -8.0133, -14.8779], [-14.9831, -13.4818, -13.1566, -12.7259, -10.5892, -10.8605, -17.2376, -15.9398, -12.8739, -13.4818, -16.6979, -13.3403, -11.6416, -11.392, -16.9553, -15.723, -13.2643, -13.4818, -16.2067, -15.6688, -15.0449, -15.1253, -15.1373, -12.385, -13.3652, -15.9473, -14.9587, -15.5024, -13.1482, -16.6358, -12.3908, 
-15.7493]], # noqa: E231 + device=torch_device, + ) + # fmt: on + + self.assertTrue(torch.allclose(start_logits[:, 64:96], target_start_logits, atol=1e-4)) + self.assertTrue(torch.allclose(end_logits[:, 64:96], target_end_logits, atol=1e-4)) + + input_ids = inputs["input_ids"].tolist() + answer = [ + input_ids[i][torch.argmax(start_logits, dim=-1)[i] : torch.argmax(end_logits, dim=-1)[i] + 1] + for i in range(len(input_ids)) + ] + answer = tokenizer.batch_decode(answer) + + self.assertTrue(answer == ["32", "[SEP]"]) + + def test_fill_mask(self): + tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base") + model = BigBirdForMaskedLM.from_pretrained("google/bigbird-roberta-base") + model.to(torch_device) + + input_ids = tokenizer("The goal of life is [MASK] .", return_tensors="pt").input_ids.to(torch_device) + logits = model(input_ids).logits + + # [MASK] is token at 6th position + pred_token = tokenizer.decode(torch.argmax(logits[0, 6:7], axis=-1)) + self.assertEqual(pred_token, "happiness") + + def test_auto_padding(self): + model = BigBirdModel.from_pretrained( + "google/bigbird-roberta-base", attention_type="block_sparse", num_random_blocks=3, block_size=16 + ) + model.to(torch_device) + model.eval() + + input_ids = torch.tensor([200 * [10] + 40 * [2] + [1]], device=torch_device, dtype=torch.long) + output = model(input_ids).to_tuple()[0] + + # fmt: off + target = torch.tensor( + [[-0.045136, -0.068013, 0.12246, -0.01356, 0.018386, 0.025333, -0.0044439, -0.0030996, -0.064031, 0.0006439], [-0.045018, -0.067638, 0.12317, -0.013998, 0.019216, 0.025695, -0.0043705, -0.0031895, -0.063153, 0.00088899], [-0.045042, -0.067305, 0.1234, -0.014512, 0.020057, 0.026084, -0.004615, -0.0031728, -0.062442, 0.0010263], [-0.044589, -0.067655, 0.12416, -0.014287, 0.019416, 0.026065, -0.0050958, -0.002702, -0.063158, 0.0004827], [-0.044627, -0.067535, 0.1239, -0.014319, 0.019491, 0.026213, -0.0059482, -0.0025906, -0.063116, 0.00014669], [-0.044899, -0.067704, 0.12337, -0.014231, 0.019256, 0.026345, -0.0065565, -0.0022938, -0.063433, -0.00011409], [-0.045599, -0.067764, 0.12235, -0.014151, 0.019206, 0.026417, -0.0068965, -0.0024494, -0.063313, -4.4499e-06], [-0.045557, -0.068372, 0.12199, -0.013747, 0.017962, 0.026103, -0.0070607, -0.0023552, -0.06447, -0.00048756], [-0.045334, -0.068913, 0.1217, -0.013566, 0.01693, 0.025745, -0.006311, -0.0024903, -0.065575, -0.0006719], [-0.045171, -0.068726, 0.12164, -0.013688, 0.017139, 0.025629, -0.005213, -0.0029412, -0.065237, -0.00020669], [-0.044411, -0.069267, 0.12206, -0.013645, 0.016212, 0.025589, -0.0044121, -0.002972, -0.066277, -0.00067963], [-0.043487, -0.069792, 0.1232, -0.013663, 0.015303, 0.02613, -0.0036294, -0.0030616, -0.067483, -0.0012642], [-0.042622, -0.069287, 0.12469, -0.013936, 0.016204, 0.026474, -0.0040534, -0.0027365, -0.066994, -0.0014148], [-0.041879, -0.070031, 0.12593, -0.014047, 0.015082, 0.027751, -0.0040683, -0.0027189, -0.068985, -0.0027146]], # noqa: E231 + device=torch_device, + ) + # fmt: on + + self.assertEqual(output.shape, torch.Size((1, 241, 768))) + self.assertTrue(torch.allclose(output[0, 64:78, 300:310], target, atol=0.0001)) diff --git a/tests/test_tokenization_big_bird.py b/tests/test_tokenization_big_bird.py new file mode 100644 index 00000000000000..967ef510bad430 --- /dev/null +++ b/tests/test_tokenization_big_bird.py @@ -0,0 +1,179 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import unittest
+
+from transformers import BigBirdTokenizer
+from transformers.file_utils import cached_property
+from transformers.testing_utils import require_sentencepiece, require_torch, slow
+
+from .test_tokenization_common import TokenizerTesterMixin
+
+
+SPIECE_UNDERLINE = "▁"
+
+SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
+
+
+@require_sentencepiece
+class BigBirdTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+
+    tokenizer_class = BigBirdTokenizer
+
+    def setUp(self):
+        super().setUp()
+
+        tokenizer = BigBirdTokenizer(SAMPLE_VOCAB, keep_accents=True)
+        tokenizer.save_pretrained(self.tmpdirname)
+
+    def test_full_tokenizer(self):
+        tokenizer = BigBirdTokenizer(SAMPLE_VOCAB, keep_accents=True)
+
+        tokens = tokenizer.tokenize("This is a test")
+        self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
+
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(tokens),
+            [285, 46, 10, 170, 382],
+        )
+
+        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
+        self.assertListEqual(
+            tokens,
+            [
+                SPIECE_UNDERLINE + "I",
+                SPIECE_UNDERLINE + "was",
+                SPIECE_UNDERLINE + "b",
+                "or",
+                "n",
+                SPIECE_UNDERLINE + "in",
+                SPIECE_UNDERLINE + "",
+                "9",
+                "2",
+                "0",
+                "0",
+                "0",
+                ",",
+                SPIECE_UNDERLINE + "and",
+                SPIECE_UNDERLINE + "this",
+                SPIECE_UNDERLINE + "is",
+                SPIECE_UNDERLINE + "f",
+                "al",
+                "s",
+                "é",
+                ".",
+            ],
+        )
+        ids = tokenizer.convert_tokens_to_ids(tokens)
+        self.assertListEqual(
+            ids,
+            [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4],
+        )
+
+        back_tokens = tokenizer.convert_ids_to_tokens(ids)
+        self.assertListEqual(
+            back_tokens,
+            [
+                SPIECE_UNDERLINE + "I",
+                SPIECE_UNDERLINE + "was",
+                SPIECE_UNDERLINE + "b",
+                "or",
+                "n",
+                SPIECE_UNDERLINE + "in",
+                SPIECE_UNDERLINE + "",
+                "<unk>",
+                "2",
+                "0",
+                "0",
+                "0",
+                ",",
+                SPIECE_UNDERLINE + "and",
+                SPIECE_UNDERLINE + "this",
+                SPIECE_UNDERLINE + "is",
+                SPIECE_UNDERLINE + "f",
+                "al",
+                "s",
+                "<unk>",
+                ".",
+            ],
+        )
+
+    @cached_property
+    def big_tokenizer(self):
+        return BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
+
+    @slow
+    def test_tokenization_base_easy_symbols(self):
+        symbols = "Hello World!"
+        original_tokenizer_encodings = [65, 18536, 2260, 101, 66]
+
+        self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))
+
+    @slow
+    def test_tokenization_base_hard_symbols(self):
+        symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to <unk>, such as saoneuhaoesuth'
+        # fmt: off
+        original_tokenizer_encodings = [65, 871, 419, 358, 946, 991, 2521, 452, 358, 1357, 387, 7751, 3536, 112, 985, 456, 126, 865, 938, 5400, 5734, 458, 1368, 467, 786, 2462, 5246, 1159, 633, 865, 4519, 457, 582, 852, 2557, 427, 916, 508, 405, 34324, 497, 391, 408, 11342, 1244, 385, 100, 938, 985, 456, 574, 362, 12597, 3200, 3129, 1172, 66]  # noqa: E231
+        # fmt: on
+        self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))
+
+    @require_torch
+    @slow
+    def test_torch_encode_plus_sent_to_model(self):
+        import torch
+
+        from transformers import BigBirdConfig, BigBirdModel
+
+        # Build sequence
+        first_ten_tokens = list(self.big_tokenizer.get_vocab().keys())[:10]
+        sequence = " ".join(first_ten_tokens)
+        encoded_sequence = self.big_tokenizer.encode_plus(sequence, return_tensors="pt", return_token_type_ids=False)
+        batch_encoded_sequence = self.big_tokenizer.batch_encode_plus(
+            [sequence + " " + sequence], return_tensors="pt", return_token_type_ids=False
+        )
+
+        config = BigBirdConfig(attention_type="original_full")
+        model = BigBirdModel(config)
+
+        assert model.get_input_embeddings().weight.shape[0] >= self.big_tokenizer.vocab_size
+
+        with torch.no_grad():
+            model(**encoded_sequence)
+            model(**batch_encoded_sequence)
+
+    @slow
+    def test_special_tokens(self):
+        """
+        To reproduce:
+
+        $ wget https://github.com/google-research/bigbird/blob/master/bigbird/vocab/gpt2.model?raw=true
+        $ mv gpt2.model?raw=true gpt2.model
+
+        ```
+        import tensorflow_text as tft
+        import tensorflow as tf
+
+        vocab_model_file = "./gpt2.model"
+        tokenizer = tft.SentencepieceTokenizer(model=tf.io.gfile.GFile(vocab_model_file, "rb").read())
+        ids = tokenizer.tokenize("Paris is the [MASK].")
+        ids = tf.concat([tf.constant([65]), ids, tf.constant([66])], axis=0)
+        detokenized = tokenizer.detokenize(ids)  # should give [CLS] Paris is the [MASK].[SEP]
+        ```
+        """
+        tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
+        decoded_text = tokenizer.decode(tokenizer("Paris is the [MASK].").input_ids)
+
+        self.assertTrue(decoded_text == "[CLS] Paris is the [MASK].[SEP]")