diff --git a/README.md b/README.md
index fa665f808054b3..2740feecc4591c 100644
--- a/README.md
+++ b/README.md
@@ -417,6 +417,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
1. **[MGP-STR](https://huggingface.co/docs/transformers/model_doc/mgp-str)** (from Alibaba Research) released with the paper [Multi-Granularity Prediction for Scene Text Recognition](https://arxiv.org/abs/2209.03592) by Peng Wang, Cheng Da, and Cong Yao.
1. **[Mistral](https://huggingface.co/docs/transformers/model_doc/mistral)** (from Mistral AI) by The [Mistral AI](https://mistral.ai) team: Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
+1. **[Mixtral](https://huggingface.co/docs/transformers/main/model_doc/mixtral)** (from Mistral AI) by The [Mistral AI](https://mistral.ai) team: Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (from Studio Ousia) released with the paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka.
1. **[MMS](https://huggingface.co/docs/transformers/model_doc/mms)** (from Facebook) released with the paper [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516) by Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli.
1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (from CMU/Google Brain) released with the paper [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou.
diff --git a/README_es.md b/README_es.md
index dfed9033276844..9cdbd351ce50f7 100644
--- a/README_es.md
+++ b/README_es.md
@@ -392,6 +392,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt
1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
1. **[MGP-STR](https://huggingface.co/docs/transformers/model_doc/mgp-str)** (from Alibaba Research) released with the paper [Multi-Granularity Prediction for Scene Text Recognition](https://arxiv.org/abs/2209.03592) by Peng Wang, Cheng Da, and Cong Yao.
1. **[Mistral](https://huggingface.co/docs/transformers/model_doc/mistral)** (from Mistral AI) by The Mistral AI team: Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed..
+1. **[Mixtral](https://huggingface.co/docs/transformers/main/model_doc/mixtral)** (from Mistral AI) by The [Mistral AI](https://mistral.ai) team: Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (from Studio Ousia) released with the paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka.
1. **[MMS](https://huggingface.co/docs/transformers/model_doc/mms)** (from Facebook) released with the paper [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516) by Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli.
1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (from CMU/Google Brain) released with the paper [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou.
diff --git a/README_hd.md b/README_hd.md
index 636d02b99a1e44..1e9adbc02ed8d1 100644
--- a/README_hd.md
+++ b/README_hd.md
@@ -366,6 +366,7 @@ conda install -c huggingface transformers
1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (NVIDIA से) साथ वाला पेपर [Megatron-LM: ट्रेनिंग मल्टी-बिलियन पैरामीटर लैंग्वेज मॉडल्स यूजिंग मॉडल पैरेललिज़्म] (https://arxiv.org/abs/1909.08053) मोहम्मद शोएबी, मोस्टोफा पटवारी, राउल पुरी, पैट्रिक लेग्रेस्ले, जेरेड कैस्पर और ब्रायन कैटानज़ारो द्वारा पोस्ट किया गया।
1. **[MGP-STR](https://huggingface.co/docs/transformers/model_doc/mgp-str)** (Alibaba Research से) Peng Wang, Cheng Da, and Cong Yao. द्वाराअनुसंधान पत्र [Multi-Granularity Prediction for Scene Text Recognition](https://arxiv.org/abs/2209.03592) के साथ जारी किया गया
1. **[Mistral](https://huggingface.co/docs/transformers/model_doc/mistral)** (from Mistral AI) by The Mistral AI team: Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed..
+1. **[Mixtral](https://huggingface.co/docs/transformers/main/model_doc/mixtral)** (from Mistral AI) by The [Mistral AI](https://mistral.ai) team: Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (फ्रॉम Studio Ousia) साथ में पेपर [mLUKE: द पावर ऑफ एंटिटी रिप्रेजेंटेशन इन मल्टीलिंगुअल प्रीट्रेन्ड लैंग्वेज मॉडल्स](https://arxiv.org/abs/2110.08151) रयोकन री, इकुया यामाडा, और योशिमासा त्सुरोका द्वारा।
1. **[MMS](https://huggingface.co/docs/transformers/model_doc/mms)** (Facebook से) Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli. द्वाराअनुसंधान पत्र [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516) के साथ जारी किया गया
1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (सीएमयू/गूगल ब्रेन से) साथ में कागज [मोबाइलबर्ट: संसाधन-सीमित उपकरणों के लिए एक कॉम्पैक्ट टास्क-अज्ञेय बीईआरटी] (https://arxiv.org/abs/2004.02984) Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, और Denny Zhou द्वारा पोस्ट किया गया।
diff --git a/README_ja.md b/README_ja.md
index 03b2f84b038cd7..cd40d4f4b9939a 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -426,6 +426,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ
1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (NVIDIA から) Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro から公開された研究論文: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)
1. **[MGP-STR](https://huggingface.co/docs/transformers/model_doc/mgp-str)** (Alibaba Research から) Peng Wang, Cheng Da, and Cong Yao. から公開された研究論文 [Multi-Granularity Prediction for Scene Text Recognition](https://arxiv.org/abs/2209.03592)
1. **[Mistral](https://huggingface.co/docs/transformers/model_doc/mistral)** (from Mistral AI) by The Mistral AI team: Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed..
+1. **[Mixtral](https://huggingface.co/docs/transformers/main/model_doc/mixtral)** (from Mistral AI) by The [Mistral AI](https://mistral.ai) team: Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (Studio Ousia から) Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka から公開された研究論文: [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151)
1. **[MMS](https://huggingface.co/docs/transformers/model_doc/mms)** (Facebook から) Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli. から公開された研究論文 [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516)
1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (CMU/Google Brain から) Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou から公開された研究論文: [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984)
diff --git a/README_ko.md b/README_ko.md
index 3cfe66be056d9a..344ecabb871485 100644
--- a/README_ko.md
+++ b/README_ko.md
@@ -341,6 +341,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (NVIDIA 에서) Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro 의 [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) 논문과 함께 발표했습니다.
1. **[MGP-STR](https://huggingface.co/docs/transformers/model_doc/mgp-str)** (Alibaba Research 에서 제공)은 Peng Wang, Cheng Da, and Cong Yao.의 [Multi-Granularity Prediction for Scene Text Recognition](https://arxiv.org/abs/2209.03592)논문과 함께 발표했습니다.
1. **[Mistral](https://huggingface.co/docs/transformers/model_doc/mistral)** (from Mistral AI) by The Mistral AI team: Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed..
+1. **[Mixtral](https://huggingface.co/docs/transformers/main/model_doc/mixtral)** (from Mistral AI) by The [Mistral AI](https://mistral.ai) team: Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (Studio Ousia 에서) Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka 의 [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) 논문과 함께 발표했습니다.
1. **[MMS](https://huggingface.co/docs/transformers/model_doc/mms)** (Facebook 에서 제공)은 Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli.의 [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516)논문과 함께 발표했습니다.
1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (CMU/Google Brain 에서) Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou 의 [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) 논문과 함께 발표했습니다.
diff --git a/README_zh-hans.md b/README_zh-hans.md
index a555d8d05dc589..d24cc81b25311f 100644
--- a/README_zh-hans.md
+++ b/README_zh-hans.md
@@ -365,6 +365,7 @@ conda install -c huggingface transformers
1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (来自 NVIDIA) 伴随论文 [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) 由 Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro 发布。
1. **[MGP-STR](https://huggingface.co/docs/transformers/model_doc/mgp-str)** (来自 Alibaba Research) 伴随论文 [Multi-Granularity Prediction for Scene Text Recognition](https://arxiv.org/abs/2209.03592) 由 Peng Wang, Cheng Da, and Cong Yao 发布。
1. **[Mistral](https://huggingface.co/docs/transformers/model_doc/mistral)** (from Mistral AI) by The Mistral AI team: Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed..
+1. **[Mixtral](https://huggingface.co/docs/transformers/main/model_doc/mixtral)** (from Mistral AI) by The [Mistral AI](https://mistral.ai) team: Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (来自 Studio Ousia) 伴随论文 [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) 由 Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka 发布。
1. **[MMS](https://huggingface.co/docs/transformers/model_doc/mms)** (来自 Facebook) 伴随论文 [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516) 由 Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli 发布。
1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (来自 CMU/Google Brain) 伴随论文 [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) 由 Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou 发布。
diff --git a/README_zh-hant.md b/README_zh-hant.md
index 02eca0c3e7026e..270a2af3730844 100644
--- a/README_zh-hant.md
+++ b/README_zh-hant.md
@@ -377,6 +377,7 @@ conda install -c huggingface transformers
1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
1. **[MGP-STR](https://huggingface.co/docs/transformers/model_doc/mgp-str)** (from Alibaba Research) released with the paper [Multi-Granularity Prediction for Scene Text Recognition](https://arxiv.org/abs/2209.03592) by Peng Wang, Cheng Da, and Cong Yao.
1. **[Mistral](https://huggingface.co/docs/transformers/model_doc/mistral)** (from Mistral AI) by The Mistral AI team: Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed..
+1. **[Mixtral](https://huggingface.co/docs/transformers/main/model_doc/mixtral)** (from Mistral AI) by The [Mistral AI](https://mistral.ai) team: Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (from Studio Ousia) released with the paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka.
1. **[MMS](https://huggingface.co/docs/transformers/model_doc/mms)** (from Facebook) released with the paper [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516) by Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli.
1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (from CMU/Google Brain) released with the paper [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou.
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index d6312509c49e10..f221de6d38dda2 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -396,6 +396,8 @@
title: MegatronGPT2
- local: model_doc/mistral
title: Mistral
+ - local: model_doc/mixtral
+ title: Mixtral
- local: model_doc/mluke
title: mLUKE
- local: model_doc/mobilebert
diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index 89a80df5b83787..796e4ca98e587e 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -189,6 +189,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Megatron-GPT2](model_doc/megatron_gpt2) | ✅ | ✅ | ✅ |
| [MGP-STR](model_doc/mgp-str) | ✅ | ❌ | ❌ |
| [Mistral](model_doc/mistral) | ✅ | ❌ | ❌ |
+| [Mixtral](model_doc/mixtral) | ✅ | ❌ | ❌ |
| [mLUKE](model_doc/mluke) | ✅ | ❌ | ❌ |
| [MMS](model_doc/mms) | ✅ | ✅ | ✅ |
| [MobileBERT](model_doc/mobilebert) | ✅ | ✅ | ❌ |
diff --git a/docs/source/en/model_doc/mixtral.md b/docs/source/en/model_doc/mixtral.md
new file mode 100644
index 00000000000000..a0c5b157186426
--- /dev/null
+++ b/docs/source/en/model_doc/mixtral.md
@@ -0,0 +1,163 @@
+
+
+# Mixtral
+
+## Overview
+
+Mixtral-8x7B is Mistral AI's second Large Language Model (LLM).
+
+The Mixtral model was proposed in the by the [Mistral AI](https://mistral.ai/) team.
+
+It was introduced in the [Mixtral of Experts blogpost](https://mistral.ai/news/mixtral-of-experts/) with the following introduction:
+
+*Today, the team is proud to release Mixtral 8x7B, a high-quality sparse mixture of experts models (SMoE) with open weights. Licensed under Apache 2.0. Mixtral outperforms Llama 2 70B on most benchmarks with 6x faster inference. It is the strongest open-weight model with a permissive license and the best model overall regarding cost/performance trade-offs. In particular, it matches or outperforms GPT3.5 on most standard benchmarks.*
+
+Tips:
+
+
+- The model needs to be converted using the [conversion script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/convert_mixtral_weights_to_hf.py).
+- If the model is quantized to 4bits, a single A100 is enough to fit the entire 84B model.
+
+This model was contributed by [Younes Belkada](https://huggingface.co/ybelkada) and [Arthur Zucker](https://huggingface.co/ArthurZ) .
+The original code can be found [here](https://github.com/mistralai/mistral-src).
+
+
+### Model Details
+
+Mixtral-84B is a decoder-based LM with the following architectural choices:
+
+* Mixtral is a Mixture of Expert (MOE) model with 8 experts per MLP, with a total of 85B paramateres but the compute required is the same as a 14B model. This is because even though each experts have to be loaded in RAM (70B like ram requirement) each token from the hidden states are dipatched twice (top 2 routing) and thus the compute (the operation required at each foward computation) is just 2 X sequence_length.
+
+The following implementation details are shared with Mistral AI's first model [mistral](~models/doc/mistral):
+* Sliding Window Attention - Trained with 8k context length and fixed cache size, with a theoretical attention span of 128K tokens
+* GQA (Grouped Query Attention) - allowing faster inference and lower cache size.
+* Byte-fallback BPE tokenizer - ensures that characters are never mapped to out of vocabulary tokens.
+
+They also provide an instruction fine-tuned model: `mistralai/Mixtral-8x7B-v0.1` which can be used for chat-based inference.
+
+For more details please read our [release blog post](https://mistral.ai/news/mixtral-of-experts/)
+
+### License
+
+`Mixtral-8x7B` is released under the Apache 2.0 license.
+
+## Usage tips
+
+`Mixtral-8x7B` can be found on the [Huggingface Hub](https://huggingface.co/mistralai)
+
+These ready-to-use checkpoints can be downloaded and used via the HuggingFace Hub:
+
+```python
+>>> from transformers import AutoModelForCausalLM, AutoTokenizer
+>>> device = "cuda" # the device to load the model onto
+
+>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
+>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-8x7B")
+
+>>> prompt = "My favourite condiment is"
+
+>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
+>>> model.to(device)
+
+>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
+>>> tokenizer.batch_decode(generated_ids)[0]
+"The expected output"
+```
+
+To use the raw checkpoints with HuggingFace you can use the `convert_mixtral_weights_to_hf.py` script to convert them to the HuggingFace format:
+
+```bash
+python src/transformers/models/mixtral/convert_mixtral_weights_to_hf.py \
+ --input_dir /path/to/downloaded/mistral/weights --output_dir /output/path
+```
+
+You can then load the converted model from the `output/path`:
+
+```python
+from transformers import MixtralForCausalLM, LlamaTokenizer
+
+tokenizer = LlamaTokenizer.from_pretrained("/output/path")
+model = MixtralForCausalLM.from_pretrained("/output/path")
+```
+
+## Combining Mixtral and Flash Attention 2
+
+First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
+
+```bash
+pip install -U flash-attn --no-build-isolation
+```
+
+Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of [`flash-attn`](https://github.com/Dao-AILab/flash-attention) repository. Make also sure to load your model in half-precision (e.g. `torch.float16`)
+
+To load and run a model using Flash Attention 2, refer to the snippet below:
+
+```python
+>>> import torch
+>>> from transformers import AutoModelForCausalLM, AutoTokenizer
+>>> device = "cuda" # the device to load the model onto
+
+>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
+>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
+
+>>> prompt = "My favourite condiment is"
+
+>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
+>>> model.to(device)
+
+>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
+>>> tokenizer.batch_decode(generated_ids)[0]
+"The expected output"
+```
+
+### Expected speedups
+
+Below is a expected speedup diagram that compares pure inference time between the native implementation in transformers using `mistralai/Mixtral-8x7B-v0.1` checkpoint and the Flash Attention 2 version of the model.
+
+
+
+
+
+### Sliding window Attention
+
+The current implementation supports the sliding window attention mechanism and memory efficient cache management.
+To enable sliding window attention, just make sure to have a `flash-attn` version that is compatible with sliding window attention (`>=2.3.0`).
+
+The Flash Attention-2 model uses also a more memory efficient cache slicing mechanism - as recommended per the official implementation of Mistral model that use rolling cache mechanism we keep the cache size fixed (`self.config.sliding_window`), support batched generation only for `padding_side="left"` and use the absolute position of the current token to compute the positional embedding.
+
+## The Mistral Team
+
+Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
+
+## MixtralConfig
+
+[[autodoc]] MixtralConfig
+
+## MixtralModel
+
+[[autodoc]] MixtralModel
+ - forward
+
+## MixtralForCausalLM
+
+[[autodoc]] MixtralForCausalLM
+ - forward
+
+## MixtralForSequenceClassification
+
+[[autodoc]] MixtralForSequenceClassification
+ - forward
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index b12670584a4ec5..f194cf2db952b1 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -48,6 +48,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
+* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
diff --git a/docs/source/en/tasks/language_modeling.md b/docs/source/en/tasks/language_modeling.md
index 2eac6ec12328a1..a50555dfcf941a 100644
--- a/docs/source/en/tasks/language_modeling.md
+++ b/docs/source/en/tasks/language_modeling.md
@@ -37,7 +37,7 @@ You can finetune other architectures for causal language modeling following the
Choose one of the following architectures:
-[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)
+[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)
diff --git a/docs/source/en/tasks/sequence_classification.md b/docs/source/en/tasks/sequence_classification.md
index c6daa66f362f2a..4a0e5b611c9136 100644
--- a/docs/source/en/tasks/sequence_classification.md
+++ b/docs/source/en/tasks/sequence_classification.md
@@ -33,7 +33,7 @@ The task illustrated in this tutorial is supported by the following model archit
-[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [CodeLlama](../model_doc/code_llama), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [T5](../model_doc/t5), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
+[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [CodeLlama](../model_doc/code_llama), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [T5](../model_doc/t5), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 5e2e6ea13fd123..3b03c606bb3253 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -595,6 +595,7 @@
"MgpstrTokenizer",
],
"models.mistral": ["MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP", "MistralConfig"],
+ "models.mixtral": ["MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP", "MixtralConfig"],
"models.mluke": [],
"models.mobilebert": [
"MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
@@ -2584,6 +2585,9 @@
"MistralPreTrainedModel",
]
)
+ _import_structure["models.mixtral"].extend(
+ ["MixtralForCausalLM", "MixtralForSequenceClassification", "MixtralModel", "MixtralPreTrainedModel"]
+ )
_import_structure["models.mobilebert"].extend(
[
"MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -5234,6 +5238,7 @@
MgpstrTokenizer,
)
from .models.mistral import MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP, MistralConfig
+ from .models.mixtral import MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP, MixtralConfig
from .models.mobilebert import (
MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
MobileBertConfig,
@@ -7006,6 +7011,12 @@
MistralModel,
MistralPreTrainedModel,
)
+ from .models.mixtral import (
+ MixtralForCausalLM,
+ MixtralForSequenceClassification,
+ MixtralModel,
+ MixtralPreTrainedModel,
+ )
from .models.mobilebert import (
MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MobileBertForMaskedLM,
diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py
index aceec7abd40643..cbee6a292b531b 100755
--- a/src/transformers/modeling_outputs.py
+++ b/src/transformers/modeling_outputs.py
@@ -368,6 +368,97 @@ class MoEModelOutput(ModelOutput):
router_probs: Optional[Tuple[torch.FloatTensor]] = None
+@dataclass
+class MoeModelOutputWithPast(ModelOutput):
+ """
+ Base class for model's outputs, with potential hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
+ encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+ input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary
+ loss for Mixture of Experts models.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ router_logits: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class MoeCausalLMOutputWithPast(ModelOutput):
+ """
+ Base class for causal language model (or autoregressive) with mixture of experts outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+
+ aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
+ aux_loss for the sparse modules.
+
+ router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary
+ loss for Mixture of Experts models.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ aux_loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ router_logits: Optional[Tuple[torch.FloatTensor]] = None
+
+
@dataclass
class MoEModelOutputWithPastAndCrossAttentions(ModelOutput):
"""
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 4442caa811575e..d14f385b45c2c2 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -136,6 +136,7 @@
megatron_gpt2,
mgp_str,
mistral,
+ mixtral,
mluke,
mobilebert,
mobilenet_v1,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 8483b6e8e89e88..9ee67e0a63f797 100755
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -144,6 +144,7 @@
("megatron-bert", "MegatronBertConfig"),
("mgp-str", "MgpstrConfig"),
("mistral", "MistralConfig"),
+ ("mixtral", "MixtralConfig"),
("mobilebert", "MobileBertConfig"),
("mobilenet_v1", "MobileNetV1Config"),
("mobilenet_v2", "MobileNetV2Config"),
@@ -365,6 +366,7 @@
("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mgp-str", "MGP_STR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mistral", "MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+ ("mixtral", "MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mobilenet_v1", "MOBILENET_V1_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mobilenet_v2", "MOBILENET_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mobilevit", "MOBILEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -598,6 +600,7 @@
("megatron_gpt2", "Megatron-GPT2"),
("mgp-str", "MGP-STR"),
("mistral", "Mistral"),
+ ("mixtral", "Mixtral"),
("mluke", "mLUKE"),
("mms", "MMS"),
("mobilebert", "MobileBERT"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 9f1b509187ae81..b9fe29cb1bbc20 100755
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -143,6 +143,7 @@
("megatron-bert", "MegatronBertModel"),
("mgp-str", "MgpstrForSceneTextRecognition"),
("mistral", "MistralModel"),
+ ("mixtral", "MixtralModel"),
("mobilebert", "MobileBertModel"),
("mobilenet_v1", "MobileNetV1Model"),
("mobilenet_v2", "MobileNetV2Model"),
@@ -431,6 +432,7 @@
("mega", "MegaForCausalLM"),
("megatron-bert", "MegatronBertForCausalLM"),
("mistral", "MistralForCausalLM"),
+ ("mixtral", "MixtralForCausalLM"),
("mpt", "MptForCausalLM"),
("musicgen", "MusicgenForCausalLM"),
("mvp", "MvpForCausalLM"),
@@ -768,6 +770,7 @@
("mega", "MegaForSequenceClassification"),
("megatron-bert", "MegatronBertForSequenceClassification"),
("mistral", "MistralForSequenceClassification"),
+ ("mixtral", "MixtralForSequenceClassification"),
("mobilebert", "MobileBertForSequenceClassification"),
("mpnet", "MPNetForSequenceClassification"),
("mpt", "MptForSequenceClassification"),
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 81e17c8d61135f..91013ab3a701a1 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -240,6 +240,13 @@
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
+ (
+ "mixtral",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
diff --git a/src/transformers/models/mixtral/__init__.py b/src/transformers/models/mixtral/__init__.py
new file mode 100644
index 00000000000000..ebde04ea4ae81c
--- /dev/null
+++ b/src/transformers/models/mixtral/__init__.py
@@ -0,0 +1,62 @@
+# Copyright 2023 Mixtral AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_torch_available,
+)
+
+
+_import_structure = {
+ "configuration_mixtral": ["MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP", "MixtralConfig"],
+}
+
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_mixtral"] = [
+ "MixtralForCausalLM",
+ "MixtralModel",
+ "MixtralPreTrainedModel",
+ "MixtralForSequenceClassification",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_mixtral import MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP, MixtralConfig
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_mixtral import (
+ MixtralForCausalLM,
+ MixtralForSequenceClassification,
+ MixtralModel,
+ MixtralPreTrainedModel,
+ )
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/mixtral/configuration_mixtral.py b/src/transformers/models/mixtral/configuration_mixtral.py
new file mode 100644
index 00000000000000..dc547068e02f42
--- /dev/null
+++ b/src/transformers/models/mixtral/configuration_mixtral.py
@@ -0,0 +1,169 @@
+# coding=utf-8
+# Copyright 2023 Mixtral AI and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Mixtral model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "mistral-ai/Mixtral-8x7B": "https://huggingface.co/mistral-ai/Mixtral-8x7B/resolve/main/config.json",
+}
+
+
+class MixtralConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MixtralModel`]. It is used to instantiate an
+ Mixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Mixtral-7B-v0.1 or Mixtral-7B-Instruct-v0.1.
+
+ [mixtralai/Mixtral-8x7B](https://huggingface.co/mixtralai/Mixtral-8x7B)
+ [mixtralai/Mixtral-7B-Instruct-v0.1](https://huggingface.co/mixtralai/Mixtral-7B-Instruct-v0.1)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the Mixtral model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`MixtralModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 14336):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
+ The maximum sequence length that this model might ever be used with. Mixtral's sliding window attention
+ allows sequence of up to 4096*32 tokens.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ The id of the padding token.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ The id of the "beginning-of-sequence" token.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the "end-of-sequence" token.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
+ The base period of the RoPE embeddings.
+ sliding_window (`int`, *optional*, defaults to 4096):
+ Sliding window attention window size. If not specified, will default to `4096`.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ num_experts_per_tok (`int`, *optional*, defaults to 2):
+ The number of experts to root per-token, can be also interpreted as the `top-p` routing
+ parameter
+ num_local_experts (`int`, *optional*, defaults to 8):
+ Number of experts per Sparse MLP layer.
+ output_router_logits (`bool`, *optional*, defaults to `False`):
+ Whether or not the router logits should be returned by the model. Enabeling this will also
+ allow the model to output the auxiliary loss. See [here]() for more details
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+ The aux loss factor for the total loss.
+
+ ```python
+ >>> from transformers import MixtralModel, MixtralConfig
+
+ >>> # Initializing a Mixtral 7B style configuration
+ >>> configuration = MixtralConfig()
+
+ >>> # Initializing a model from the Mixtral 7B style configuration
+ >>> model = MixtralModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "mixtral"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=4096,
+ intermediate_size=14336,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ max_position_embeddings=4096 * 32,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ rope_theta=1e6,
+ sliding_window=4096,
+ attention_dropout=0.0,
+ num_experts_per_tok=2,
+ num_local_experts=8,
+ output_router_logits=False,
+ router_aux_loss_coef=0.001,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.sliding_window = sliding_window
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_dropout = attention_dropout
+
+ self.num_experts_per_tok = num_experts_per_tok
+ self.num_local_experts = num_local_experts
+ self.output_router_logits = output_router_logits
+ self.router_aux_loss_coef = router_aux_loss_coef
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
diff --git a/src/transformers/models/mixtral/convert_mixtral_weights_to_hf.py b/src/transformers/models/mixtral/convert_mixtral_weights_to_hf.py
new file mode 100644
index 00000000000000..53cb8014438165
--- /dev/null
+++ b/src/transformers/models/mixtral/convert_mixtral_weights_to_hf.py
@@ -0,0 +1,244 @@
+# Copyright 2023 Mistral AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import json
+import os
+
+import torch
+
+from transformers import (
+ MixtralConfig,
+ MixtralForCausalLM,
+)
+
+
+"""
+Sample usage:
+
+```
+python src/transformers/models/mixtral/convert_mixtral_weights_to_hf.py \
+ --input_dir /path/to/downloaded/mixtral/weights --model_size 7B --output_dir /output/path
+```
+
+Thereafter, models can be loaded via:
+
+```py
+from transformers import MixtralForCausalLM
+
+model = MixtralForCausalLM.from_pretrained("/output/path")
+```
+
+Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
+come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
+"""
+
+
+def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
+ return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
+
+
+def read_json(path):
+ with open(path, "r") as f:
+ return json.load(f)
+
+
+def write_json(text, path):
+ with open(path, "w") as f:
+ json.dump(text, f)
+
+
+def write_model(model_path, input_base_path, model_size, safe_serialization=True):
+ os.makedirs(model_path, exist_ok=True)
+
+ params = read_json(os.path.join(input_base_path, "params.json"))
+ num_shards = 1
+
+ # For some reason this is a string in the params.json
+ sliding_window = int(params["sliding_window"])
+ n_layers = params["num_hidden_layers"]
+ n_heads = params["num_attention_heads"]
+ n_heads_per_shard = n_heads // num_shards
+ dim = params["hidden_size"]
+ dims_per_head = dim // n_heads
+ base = params.get("rope_theta", 10000.0)
+ max_position_embeddings = 4096 * 8
+ num_local_experts = params["num_local_experts"]
+ ffn_dim = params["intermediate_size"]
+
+ vocab_size = params["vocab_size"]
+
+ if "num_key_value_heads" in params:
+ num_key_value_heads = params["num_key_value_heads"] # for GQA / MQA
+ num_local_key_value_heads = num_key_value_heads // num_shards
+ key_value_dim = dims_per_head * num_local_key_value_heads
+ else: # compatibility with other checkpoints
+ num_key_value_heads = n_heads
+ num_local_key_value_heads = n_heads_per_shard
+ key_value_dim = dim
+
+ # permute for sliced rotary
+ def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
+ return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
+
+ print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
+ # Load weights
+ loaded = [
+ torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pt"), map_location="cpu") for i in range(8)
+ ]
+
+ merged_state_dict = {}
+ for state_dict in loaded:
+ merged_state_dict.update(state_dict)
+
+ state_dict = {}
+
+ for layer_i in range(n_layers):
+ # Sharded
+ # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
+ # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
+ # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.
+
+ state_dict.update(
+ {
+ f"model.layers.{layer_i}.input_layernorm.weight": merged_state_dict[
+ f"layers.{layer_i}.attention_norm.weight"
+ ].clone(),
+ f"model.layers.{layer_i}.post_attention_layernorm.weight": merged_state_dict[
+ f"layers.{layer_i}.ffn_norm.weight"
+ ].clone(),
+ }
+ )
+
+ state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
+ merged_state_dict[f"layers.{layer_i}.attention.wq.weight"]
+ .view(n_heads_per_shard, dims_per_head, dim)
+ .reshape(dim, dim)
+ )
+ state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
+ merged_state_dict[f"layers.{layer_i}.attention.wk.weight"]
+ .view(num_local_key_value_heads, dims_per_head, dim)
+ .reshape(key_value_dim, dim),
+ num_key_value_heads,
+ key_value_dim,
+ dim,
+ )
+ state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = (
+ merged_state_dict[f"layers.{layer_i}.attention.wv.weight"]
+ .view(num_local_key_value_heads, dims_per_head, dim)
+ .reshape(key_value_dim, dim)
+ )
+
+ state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = merged_state_dict[
+ f"layers.{layer_i}.attention.wo.weight"
+ ]
+
+ w1 = merged_state_dict[f"layers.{layer_i}.block_sparse_moe.w1"]
+ w2 = merged_state_dict[f"layers.{layer_i}.block_sparse_moe.w2"]
+ w3 = merged_state_dict[f"layers.{layer_i}.block_sparse_moe.w3"]
+
+ experts_w1 = [
+ w1[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].contiguous().clone()
+ for expert_idx in range(num_local_experts)
+ ]
+
+ for idx, expert_block in enumerate(experts_w1):
+ expert_key = f"model.layers.{layer_i}.block_sparse_moe.experts.{idx}.w1"
+ state_dict[expert_key + ".weight"] = expert_block.clone()
+
+ experts_w2 = [
+ w2[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].contiguous().clone()
+ for expert_idx in range(num_local_experts)
+ ]
+
+ for idx, expert_block in enumerate(experts_w2):
+ expert_key = f"model.layers.{layer_i}.block_sparse_moe.experts.{idx}.w2"
+ state_dict[expert_key + ".weight"] = expert_block.T.clone().contiguous()
+
+ experts_w3 = [
+ w3[ffn_dim * expert_idx : ffn_dim * (expert_idx + 1), :].contiguous().clone()
+ for expert_idx in range(num_local_experts)
+ ]
+
+ for idx, expert_block in enumerate(experts_w3):
+ expert_key = f"model.layers.{layer_i}.block_sparse_moe.experts.{idx}.w3"
+ state_dict[expert_key + ".weight"] = expert_block.clone()
+
+ state_dict[f"model.layers.{layer_i}.block_sparse_moe.gate.weight"] = merged_state_dict[
+ f"layers.{layer_i}.block_sparse_moe.gate.weight"
+ ]
+
+ state_dict.update(
+ {
+ "model.norm.weight": merged_state_dict["norm.weight"],
+ "model.embed_tokens.weight": merged_state_dict["tok_embeddings.weight"],
+ "lm_head.weight": merged_state_dict["output.weight"],
+ }
+ )
+
+ config = MixtralConfig(
+ hidden_size=dim,
+ intermediate_size=ffn_dim,
+ num_attention_heads=params["num_attention_heads"],
+ num_hidden_layers=params["num_hidden_layers"],
+ rms_norm_eps=params["rms_norm_eps"],
+ num_key_value_heads=num_key_value_heads,
+ vocab_size=vocab_size,
+ rope_theta=base,
+ max_position_embeddings=max_position_embeddings,
+ sliding_window=sliding_window,
+ num_local_experts=num_local_experts,
+ )
+
+ print("Loading the checkpoint in a Mixtral model.")
+ with torch.device("meta"):
+ model = MixtralForCausalLM(config)
+ # Avoid saving this as part of the config.
+ del model.config._name_or_path
+ model.config.torch_dtype = torch.float16
+ print("Saving in the Transformers format.")
+
+ model.load_state_dict(state_dict, strict=True, assign=True)
+
+ for n, p in model.named_parameters():
+ assert p.device.type != "meta", f"{n} has not been loaded!"
+
+ model.save_pretrained(model_path, safe_serialization=safe_serialization)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--input_dir",
+ help="Location of Mixtral weights, which contains tokenizer.model and model folders",
+ required=True,
+ )
+ parser.add_argument(
+ "--model_size",
+ choices=["7B"],
+ help="'f' models correspond to the finetuned versions, and are specific to the Mixtral official release. For more details on Mixtral, checkout the original repo: https://huggingface.co/mistral-ai",
+ default="7B",
+ )
+ parser.add_argument("--output_dir", help="Location to write HF model", required=True)
+ parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
+ args = parser.parse_args()
+ write_model(
+ model_path=args.output_dir,
+ input_base_path=args.input_dir,
+ model_size=args.model_size,
+ safe_serialization=args.safe_serialization,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
new file mode 100644
index 00000000000000..6788dcde8c3e22
--- /dev/null
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -0,0 +1,1453 @@
+# coding=utf-8
+# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Mixtral model."""
+import inspect
+import math
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...modeling_attn_mask_utils import (
+ _prepare_4d_causal_attention_mask,
+)
+from ...modeling_outputs import (
+ MoeCausalLMOutputWithPast,
+ MoeModelOutputWithPast,
+ SequenceClassifierOutputWithPast,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
+from ...utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+from ...utils.import_utils import is_torch_fx_available
+from .configuration_mixtral import MixtralConfig
+
+
+if is_flash_attn_2_available():
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+
+# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
+# It means that the function will not be traced through and simply appear as a node in the graph.
+if is_torch_fx_available():
+ if not is_torch_greater_or_equal_than_1_13:
+ import torch.fx
+
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "MixtralConfig"
+
+
+def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2) -> float:
+ r"""
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+ experts is too unbalanced.
+
+ Args:
+ gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
+ Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, seqeunce_length, num_experts].
+ num_experts (`int`, *optional*):
+ Number of experts
+
+ Returns:
+ The auxiliary loss.
+ """
+ if gate_logits is None:
+ return 0
+
+ if isinstance(gate_logits, tuple):
+ # cat along the layers?
+ gate_logits = torch.cat(gate_logits, dim=0)
+
+ routing_weights, selected_experts = torch.topk(gate_logits, top_k, dim=-1)
+ routing_weights = routing_weights.softmax(dim=-1)
+
+ # cast the expert indices to int64, otherwise one-hot encoding will fail
+ if selected_experts.dtype != torch.int64:
+ selected_experts = selected_experts.to(torch.int64)
+
+ if len(selected_experts.shape) == 2:
+ selected_experts = selected_experts.unsqueeze(2)
+
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+ # For a given token, determine if it was routed to a given expert.
+ expert_mask = torch.max(expert_mask, axis=-2).values
+
+ # cast to float32 otherwise mean will fail
+ expert_mask = expert_mask.to(torch.float32)
+ tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2)
+
+ router_prob_per_group_and_expert = torch.mean(routing_weights, axis=-1)
+ return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert.unsqueeze(-1)) * (num_experts**2)
+
+
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral
+class MixtralRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ MixtralRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mixtral
+class MixtralRotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ # Build here to make `torch.jit.trace` work.
+ self._set_cos_sin_cache(
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+ )
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+ freqs = torch.outer(t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if seq_len > self.max_seq_len_cached:
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+ return (
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
+ )
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`):
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+ used to pass offsetted position ids when working with a KV-cache.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
+class MixtralAttention(nn.Module):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+ and "Generating Long Sequences with Sparse Transformers".
+ """
+
+ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+ self.attention_dropout = config.attention_dropout
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ self.rotary_emb = MixtralRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if "padding_mask" in kwargs:
+ warnings.warn(
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+ )
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
+class MixtralFlashAttention2(MixtralAttention):
+ """
+ Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ):
+ if "padding_mask" in kwargs:
+ warnings.warn(
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+ )
+
+ # overwrite attention_mask with padding_mask
+ attention_mask = kwargs.pop("padding_mask")
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ use_sliding_windows = (
+ _flash_supports_window_size
+ and getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ )
+
+ if not _flash_supports_window_size:
+ logger.warning_once(
+ "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
+ " make sure to upgrade flash-attn library."
+ )
+
+ if past_key_value is not None:
+ # Activate slicing cache only if the config has a value `sliding_windows` attribute
+ if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[0]
+ past_value = past_key_value[1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ past_key_value = (past_key, past_value)
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ # Handle the case where the model is quantized
+ if hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reashape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = self._flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=dropout_rate,
+ use_sliding_windows=use_sliding_windows,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+ def _flash_attention_forward(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ query_length,
+ dropout=0.0,
+ softmax_scale=None,
+ use_sliding_windows=False,
+ ):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+ first unpad the input, then computes the attention scores and pad the final attention scores.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`int`, *optional*):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ use_sliding_windows (`bool`, *optional*):
+ Whether to activate sliding window attention.
+ """
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+ query_states, key_states, value_states, attention_mask, query_length
+ )
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ if not use_sliding_windows:
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+ else:
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size=(self.config.sliding_window, self.config.sliding_window),
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+ else:
+ if not use_sliding_windows:
+ attn_output = flash_attn_func(
+ query_states,
+ key_states,
+ value_states,
+ dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+ else:
+ attn_output = flash_attn_func(
+ query_states,
+ key_states,
+ value_states,
+ dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size=(self.config.sliding_window, self.config.sliding_window),
+ )
+
+ return attn_output
+
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
+
+ # On the first iteration we need to properly re-create the padding mask
+ # by slicing it on the proper place
+ if kv_seq_len != attention_mask.shape[-1]:
+ attention_mask_num_tokens = attention_mask.shape[-1]
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
+
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
+ )
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+class MixtralBLockSparseTop2MLP(nn.Module):
+ def __init__(self, config: MixtralConfig):
+ super().__init__()
+ self.ffn_dim = config.intermediate_size
+ self.hidden_dim = config.hidden_size
+
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_states, routing_weights):
+ current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+ current_hidden_states = self.w2(current_hidden_states)
+ return routing_weights * current_hidden_states
+
+
+MISTRAL_ATTENTION_CLASSES = {
+ "eager": MixtralAttention,
+ "flash_attention_2": MixtralFlashAttention2,
+}
+
+
+class MixtralSparseMoeBlock(nn.Module):
+ """
+ This implementation is
+ strictly equivalent to standard MoE with full capacity (no
+ dropped tokens). It's faster since it formulates MoE operations
+ in terms of block-sparse operations to accomodate imbalanced
+ assignments of tokens to experts, whereas standard MoE either
+ (1) drop tokens at the cost of reduced performance or (2) set
+ capacity factor to number of experts and thus waste computation
+ and memory on padding.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.hidden_dim = config.hidden_size
+ self.ffn_dim = config.intermediate_size
+ self.num_experts = config.num_local_experts
+ self.top_k = config.num_experts_per_tok
+
+ # gating
+ self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+
+ self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)])
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """ """
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_dim)
+ # router_logits: (batch * sequence_length, n_experts)
+ router_logits = self.gate(hidden_states)
+
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ # we cast back to the input dtype
+ routing_weights = routing_weights.to(hidden_states.dtype)
+
+ final_hidden_states = torch.zeros(
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+ )
+
+ # One hot encode the selected experts to create an expert mask
+ # this will be used to easily index which expert is going to be sollicitated
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+ # Loop over all available experts in the model and perform the computation on each expert
+ for expert_idx in range(self.num_experts):
+ expert_layer = self.experts[expert_idx]
+ idx, top_x = torch.where(expert_mask[expert_idx])
+
+ if top_x.shape[0] == 0:
+ continue
+
+ # in torch it is faster to index using lists than torch tensors
+ top_x_list = top_x.tolist()
+ idx_list = idx.tolist()
+
+ # Index the correct hidden states and compute the expert hidden state for
+ # the current expert. We need to make sure to multiply the output hidden
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+ current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
+ current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
+
+ # However `index_add_` only support torch tensors for indexing so we'll use
+ # the `top_x` tensor here.
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+ return final_hidden_states, router_logits
+
+
+class MixtralDecoderLayer(nn.Module):
+ def __init__(self, config: MixtralConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+
+ self.block_sparse_moe = MixtralSparseMoeBlock(config)
+ self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ output_router_logits: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ if "padding_mask" in kwargs:
+ warnings.warn(
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+ )
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_router_logits (`bool`, *optional*):
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+ should not be returned during inference.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states, router_logits = self.block_sparse_moe(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ if output_router_logits:
+ outputs += (router_logits,)
+
+ return outputs
+
+
+MIXTRAL_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`MixtralConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
+ MIXTRAL_START_DOCSTRING,
+)
+# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral
+class MixtralPreTrainedModel(PreTrainedModel):
+ config_class = MixtralConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["MixtralDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_cache_class = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+MIXTRAL_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ output_router_logits (`bool`, *optional*):
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+ should not be returned during inference.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
+ MIXTRAL_START_DOCSTRING,
+)
+# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
+class MixtralModel(MixtralPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
+
+ Args:
+ config: MixtralConfig
+ """
+
+ def __init__(self, config: MixtralConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ # Ignore copy
+ @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ past_key_values_length = 0
+
+ if use_cache:
+ use_legacy_cache = not isinstance(past_key_values, Cache)
+ if use_legacy_cache:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
+ if is_padding_right:
+ raise ValueError(
+ "You are attempting to perform batched generation with padding_side='right'"
+ " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+ )
+
+ if self._use_flash_attention_2:
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ sliding_window=self.config.sliding_window,
+ )
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_router_logits = () if output_router_logits else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ output_router_logits,
+ use_cache,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ output_router_logits=output_router_logits,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if output_router_logits:
+ all_router_logits += (layer_outputs[-1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = None
+ if use_cache:
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+ if v is not None
+ )
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ router_logits=all_router_logits,
+ )
+
+
+class MixtralForCausalLM(MixtralPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = MixtralModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.router_aux_loss_coef = config.router_aux_loss_coef
+ self.num_experts = config.num_local_experts
+ self.num_experts_per_tok = config.num_experts_per_tok
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ # Ignore copy
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, MixtralForCausalLM
+
+ >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ output_router_logits=output_router_logits,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ aux_loss = None
+ if output_router_logits:
+ aux_loss = load_balancing_loss_func(
+ outputs.router_logits if return_dict else outputs[-1], self.num_experts, self.num_experts_per_tok
+ )
+ if labels is not None:
+ loss += self.router_aux_loss_coef * aux_loss
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ if output_router_logits:
+ output = (aux_loss,) + output
+ return (loss,) + output if loss is not None else output
+
+ return MoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=aux_loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ router_logits=outputs.router_logits,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ # Omit tokens covered by past_key_values
+ if past_key_values is not None:
+ if isinstance(past_key_values, Cache):
+ cache_length = past_key_values.get_seq_length()
+ past_length = past_key_values.seen_tokens
+ max_cache_length = past_key_values.get_max_length()
+ else:
+ cache_length = past_length = past_key_values[0][0].shape[2]
+ max_cache_length = None
+
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
+ # input)
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < input_ids.shape[1]:
+ input_ids = input_ids[:, past_length:]
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+ if (
+ max_cache_length is not None
+ and attention_mask is not None
+ and cache_length + input_ids.shape[1] > max_cache_length
+ ):
+ attention_mask = attention_mask[:, -max_cache_length:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+ )
+ return reordered_past
+
+
+@add_start_docstrings(
+ """
+ The Mixtral Model transformer with a sequence classification head on top (linear layer).
+
+ [`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it requires to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """,
+ MIXTRAL_START_DOCSTRING,
+)
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL
+class MixtralForSequenceClassification(MixtralPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = MixtralModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
+ logits.device
+ )
+ else:
+ sequence_lengths = -1
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index adfdedf47c6459..f633c83765fac6 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -5262,6 +5262,34 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+class MixtralForCausalLM(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class MixtralForSequenceClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class MixtralModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class MixtralPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
diff --git a/tests/models/mixtral/__init__.py b/tests/models/mixtral/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py
new file mode 100644
index 00000000000000..a2d5af00237b74
--- /dev/null
+++ b/tests/models/mixtral/test_modeling_mixtral.py
@@ -0,0 +1,537 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch Mixtral model. """
+
+
+import tempfile
+import unittest
+
+import pytest
+
+from transformers import MixtralConfig, is_torch_available
+from transformers.testing_utils import (
+ require_flash_attn,
+ require_torch,
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
+
+from ...generation.test_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ MixtralForCausalLM,
+ MixtralForSequenceClassification,
+ MixtralModel,
+ )
+
+
+class MixtralModelTester:
+ # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.__init__
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ is_training=True,
+ use_input_mask=True,
+ use_token_type_ids=False,
+ use_labels=True,
+ vocab_size=99,
+ hidden_size=32,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ num_key_value_heads=2,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=16,
+ type_sequence_label_size=2,
+ initializer_range=0.02,
+ num_labels=3,
+ num_choices=4,
+ pad_token_id=0,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_input_mask = use_input_mask
+ self.use_token_type_ids = use_token_type_ids
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.num_labels = num_labels
+ self.num_choices = num_choices
+ self.pad_token_id = pad_token_id
+ self.scope = scope
+
+ # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+
+ token_type_ids = None
+ if self.use_token_type_ids:
+ token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
+
+ sequence_labels = None
+ token_labels = None
+ choice_labels = None
+ if self.use_labels:
+ sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+ token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+ choice_labels = ids_tensor([self.batch_size], self.num_choices)
+
+ config = self.get_config()
+
+ return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+
+ def get_config(self):
+ return MixtralConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ num_key_value_heads=self.num_key_value_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ type_vocab_size=self.type_vocab_size,
+ is_decoder=False,
+ initializer_range=self.initializer_range,
+ pad_token_id=self.pad_token_id,
+ num_experts_per_tok=2,
+ num_local_experts=2,
+ )
+
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Mixtral
+ def create_and_check_model(
+ self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+ ):
+ model = MixtralModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask)
+ result = model(input_ids)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Mixtral
+ def create_and_check_model_as_decoder(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ ):
+ config.add_cross_attention = True
+ model = MixtralModel(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids,
+ attention_mask=input_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ result = model(
+ input_ids,
+ attention_mask=input_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+ result = model(input_ids, attention_mask=input_mask)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Mixtral
+ def create_and_check_for_causal_lm(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ ):
+ model = MixtralForCausalLM(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask, labels=token_labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->Mixtral
+ def create_and_check_decoder_model_past_large_inputs(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ ):
+ config.is_decoder = True
+ config.add_cross_attention = True
+ model = MixtralForCausalLM(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ # first forward pass
+ outputs = model(
+ input_ids,
+ attention_mask=input_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=True,
+ )
+ past_key_values = outputs.past_key_values
+
+ # create hypothetical multiple next token and extent to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+ next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
+
+ # append to next input_ids and
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
+
+ output_from_no_past = model(
+ next_input_ids,
+ attention_mask=next_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_hidden_states=True,
+ )["hidden_states"][0]
+ output_from_past = model(
+ next_tokens,
+ attention_mask=next_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ output_hidden_states=True,
+ )["hidden_states"][0]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+ self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->Mixtral
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ) = config_and_inputs
+ inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
+ return config, inputs_dict
+
+
+@require_torch
+# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Mixtral
+class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ all_model_classes = (
+ (MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification) if is_torch_available() else ()
+ )
+ all_generative_model_classes = (MixtralForCausalLM,) if is_torch_available() else ()
+ pipeline_model_mapping = (
+ {
+ "feature-extraction": MixtralModel,
+ "text-classification": MixtralForSequenceClassification,
+ "text-generation": MixtralForCausalLM,
+ "zero-shot": MixtralForSequenceClassification,
+ }
+ if is_torch_available()
+ else {}
+ )
+ test_headmasking = False
+ test_pruning = False
+
+ # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
+ def is_pipeline_test_to_skip(
+ self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+ ):
+ return True
+
+ def setUp(self):
+ self.model_tester = MixtralModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=MixtralConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_various_embeddings(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ for type in ["absolute", "relative_key", "relative_key_query"]:
+ config_and_inputs[0].position_embedding_type = type
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_Mixtral_sequence_classification_model(self):
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ print(config)
+ config.num_labels = 3
+ input_ids = input_dict["input_ids"]
+ attention_mask = input_ids.ne(1).to(torch_device)
+ sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
+ model = MixtralForSequenceClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
+ self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+
+ def test_Mixtral_sequence_classification_model_for_single_label(self):
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.num_labels = 3
+ config.problem_type = "single_label_classification"
+ input_ids = input_dict["input_ids"]
+ attention_mask = input_ids.ne(1).to(torch_device)
+ sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
+ model = MixtralForSequenceClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
+ self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+
+ def test_Mixtral_sequence_classification_model_for_multi_label(self):
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.num_labels = 3
+ config.problem_type = "multi_label_classification"
+ input_ids = input_dict["input_ids"]
+ attention_mask = input_ids.ne(1).to(torch_device)
+ sequence_labels = ids_tensor(
+ [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
+ ).to(torch.float)
+ model = MixtralForSequenceClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
+ self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+
+ @unittest.skip("Mixtral buffers include complex numbers, which breaks this test")
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ @unittest.skip("Mixtral uses GQA on all models so the KV cache is a non standard format")
+ def test_past_key_values_format(self):
+ pass
+
+ @require_flash_attn
+ @require_torch_gpu
+ @pytest.mark.flash_attn_test
+ @slow
+ def test_flash_attn_2_generate_padding_right(self):
+ import torch
+
+ for model_class in self.all_generative_model_classes:
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
+
+ model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_2",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ with self.assertRaises(ValueError):
+ _ = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
+ @require_flash_attn
+ @require_torch_gpu
+ @pytest.mark.flash_attn_test
+ @slow
+ def test_flash_attn_2_generate_use_cache(self):
+ import torch
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: Mixtral apparently does not support right padding + use_cache with FA2.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_2",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
+ @require_flash_attn
+ @require_torch_gpu
+ @pytest.mark.flash_attn_test
+ @slow
+ def test_flash_attn_2_inference_padding_right(self):
+ self.skipTest("Mixtral flash attention does not support right padding")
+
+ # Ignore copy
+ def test_load_balancing_loss(self):
+ r"""
+ Let's make sure we can actually compute the loss and do a backward on it.
+ """
+
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.num_labels = 3
+ config.output_router_logits = True
+ input_ids = input_dict["input_ids"]
+ attention_mask = input_ids.ne(1).to(torch_device)
+ model = MixtralForCausalLM(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=attention_mask)
+ self.assertEqual(result.router_logits[0].shape, (91, config.num_experts_per_tok))
+ torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(1, dtype=torch.float32))
+
+
+@require_torch
+class MixtralIntegrationTest(unittest.TestCase):
+ @slow
+ @require_torch_gpu
+ def test_small_model_logits(self):
+ model_id = "hf-internal-testing/Mixtral-tiny"
+ dummy_input = torch.LongTensor([[0, 1, 0], [0, 1, 0]]).to(torch_device)
+
+ model = MixtralForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+ # TODO: might need to tweak it in case the logits do not match on our daily runners
+ # these logits have been obtained with the original megablocks impelmentation.
+ EXPECTED_LOGITS = torch.Tensor(
+ [[0.1670, 0.1620, 0.6094], [-0.8906, -0.1588, -0.6060], [0.1572, 0.1290, 0.7246]]
+ ).to(torch_device)
+
+ with torch.no_grad():
+ logits = model(dummy_input).logits
+
+ torch.testing.assert_close(logits[0, :3, :3].half(), EXPECTED_LOGITS, atol=1e-3, rtol=1e-3)
+ torch.testing.assert_close(logits[1, :3, :3].half(), EXPECTED_LOGITS, atol=1e-3, rtol=1e-3)
+
+ @slow
+ # @require_torch_gpu
+ def test_small_model_logits_batched(self):
+ model_id = "hf-internal-testing/Mixtral-tiny"
+ dummy_input = torch.LongTensor([[0, 0, 0, 0, 0, 0, 1, 2, 3], [1, 1, 2, 3, 4, 5, 6, 7, 8]]).to(torch_device)
+ attention_mask = dummy_input.ne(0).to(torch.long)
+
+ model = MixtralForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ # TODO: might need to tweak it in case the logits do not match on our daily runners
+ EXPECTED_LOGITS_LEFT = torch.Tensor(
+ [[0.1750, 0.0537, 0.7007], [0.1750, 0.0537, 0.7007], [0.1750, 0.0537, 0.7007]],
+ )
+
+ # logits[0, -3:, -3:].half()
+ EXPECTED_LOGITS_LEFT_UNPADDED = torch.Tensor(
+ [[0.2212, 0.5200, -0.3816], [0.8213, -0.2313, 0.6069], [0.2664, -0.7090, 0.2468]],
+ )
+
+ # logits[1, -3:, -3:].half()
+ EXPECTED_LOGITS_RIGHT_UNPADDED = torch.Tensor(
+ [[0.2205, 0.1232, -0.1611], [-0.3484, 0.3030, -1.0312], [0.0742, 0.7930, 0.7969]]
+ )
+
+ with torch.no_grad():
+ logits = model(dummy_input, attention_mask=attention_mask).logits
+
+ torch.testing.assert_close(logits[0, :3, :3].half(), EXPECTED_LOGITS_LEFT, atol=1e-3, rtol=1e-3)
+ torch.testing.assert_close(logits[0, -3:, -3:].half(), EXPECTED_LOGITS_LEFT_UNPADDED, atol=1e-3, rtol=1e-3)
+ torch.testing.assert_close(logits[1, -3:, -3:].half(), EXPECTED_LOGITS_RIGHT_UNPADDED, atol=1e-3, rtol=1e-3)
diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt
index 5d0c4ba2b26a8d..d1cbf347be885a 100644
--- a/utils/not_doctested.txt
+++ b/utils/not_doctested.txt
@@ -677,6 +677,8 @@ src/transformers/models/mgp_str/configuration_mgp_str.py
src/transformers/models/mgp_str/modeling_mgp_str.py
src/transformers/models/mistral/configuration_mistral.py
src/transformers/models/mistral/modeling_mistral.py
+src/transformers/models/mixtral/configuration_mixtral.py
+src/transformers/models/mixtral/modeling_mixtral.py
src/transformers/models/mluke/convert_mluke_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py
src/transformers/models/mobilenet_v1/configuration_mobilenet_v1.py